String changes (#3)

* initial commit

* additional data types & tensor type

Signed-off-by: raver119 <raver119@gmail.com>

* next step

Signed-off-by: raver119 <raver119@gmail.com>

* missing include

* sparse_to_dense

Signed-off-by: raver119 <raver119@gmail.com>

* few more tests files

Signed-off-by: raver119 <raver119@gmail.com>

* draft

Signed-off-by: raver119 <raver119@gmail.com>

* numeric sparse_to_dense

Signed-off-by: raver119 <raver119@gmail.com>

* comment

Signed-off-by: raver119 <raver119@gmail.com>

* string sparse_to_dense version

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA DataBuffer expand

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks for CUDA build

Signed-off-by: raver119 <raver119@gmail.com>

* shape fn for string_split

Signed-off-by: raver119 <raver119@gmail.com>

* one more comment

Signed-off-by: raver119 <raver119@gmail.com>

* string_split indices

Signed-off-by: raver119 <raver119@gmail.com>

* next step

Signed-off-by: raver119 <raver119@gmail.com>

* test passes

Signed-off-by: raver119 <raver119@gmail.com>

* few rearrangements for databuffer implementations

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer: move inline methods to common implementations

Signed-off-by: raver119 <raver119@gmail.com>

* add native DataBuffer to Nd4j presets

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer creation

Signed-off-by: raver119 <raver119@gmail.com>

* use DataBuffer for allocation

Signed-off-by: raver119 <raver119@gmail.com>

* cpu databuffer as deallocatable

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer setters for bufers

Signed-off-by: raver119 <raver119@gmail.com>

* couple of wrappers

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffers being passed around

Signed-off-by: raver119 <raver119@gmail.com>

* Bunch of ByteBuffer-related signatures gone

Signed-off-by: raver119 <raver119@gmail.com>

* - few more Nd4j signatures removed
- minor fix for bfloat16

Signed-off-by: raver119 <raver119@gmail.com>

* nullptr pointer is still a pointer, but 0 as address :)

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* empty string array init

Signed-off-by: raver119 <raver119@gmail.com>

* one more test in cpp

Signed-off-by: raver119 <raver119@gmail.com>

* memcpy instead of databuffer swap

Signed-off-by: raver119 <raver119@gmail.com>

* special InteropDataBuffer for front-end languages

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks for java

Signed-off-by: raver119 <raver119@gmail.com>

* pointer/indexer actualization

Signed-off-by: raver119 <raver119@gmail.com>

* CustomOp returns list for inputArumgents and outputArguments instead of array

Signed-off-by: raver119 <raver119@gmail.com>

* redundant call

Signed-off-by: raver119 <raver119@gmail.com>

* print_variable op

Signed-off-by: raver119 <raver119@gmail.com>

* - view handling (but wrong one)
- print_variable java wrapper

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* - empty arrays handling

Signed-off-by: raver119 <raver119@gmail.com>

* - deserialization works now

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* one more fix

Signed-off-by: raver119 <raver119@gmail.com>

* initial cuda commit

Signed-off-by: raver119 <raver119@gmail.com>

* print_variable message validation

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA views

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA special buffer size

Signed-off-by: raver119 <raver119@gmail.com>

* minor update to match master changes

Signed-off-by: raver119 <raver119@gmail.com>

* - consider arrays always actual on device for CUDA
- additional PrintVariable constructor
- CudaUtf8Buffer now allocates host buffer by default

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* - print_variable now allows print from device

Signed-off-by: raver119 <raver119@gmail.com>

* InteropDataBuffer data type fix

Signed-off-by: raver119 <raver119@gmail.com>

* ...

Signed-off-by: raver119 <raver119@gmail.com>

* disable some debug messages

Signed-off-by: raver119 <raver119@gmail.com>

* master pulled in

Signed-off-by: raver119 <raver119@gmail.com>

* couple of new methods for DataBuffer interop

Signed-off-by: raver119 <raver119@gmail.com>

* java side

Signed-off-by: raver119 <raver119@gmail.com>

* offsetted constructor

Signed-off-by: raver119 <raver119@gmail.com>

* new CUDA deallocator

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart 2

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart 3

Signed-off-by: raver119 <raver119@gmail.com>

* - few new tests
- few new methods for DataBuffer management

Signed-off-by: raver119 <raver119@gmail.com>

* few more tests + few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* two failing tests

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* two failing tests pass

Signed-off-by: raver119 <raver119@gmail.com>

* now we pass DataBuffer to legacy ops too

Signed-off-by: raver119 <raver119@gmail.com>

* Native DataBuffer for legacy ops, Java side

Signed-off-by: raver119 <raver119@gmail.com>

* CPU java side update

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA java side update

Signed-off-by: raver119 <raver119@gmail.com>

* no more prepare/register action on java side

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray::prepare/register use now accepts vectors

Signed-off-by: raver119 <raver119@gmail.com>

* InteropDataBuffer now has few more convenience methods

Signed-off-by: raver119 <raver119@gmail.com>

* java bindings update

Signed-off-by: raver119 <raver119@gmail.com>

* tick device in NativeOps

Signed-off-by: raver119 <raver119@gmail.com>

* Corrected usage of OpaqueBuffer for tests.

* Corrected usage of OpaqueBuffer for java tests.

* NativeOpsTests fixes.

* print_variable now returns scalar

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* compat_string_split fix for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* - CUDA execScalar fix
- CUDA lazyAllocateHostPointer now checks java indexer/pointer instead of native pointer

Signed-off-by: raver119 <raver119@gmail.com>

* legacy ops DataBuffer migration prototype

Signed-off-by: raver119 <raver119@gmail.com>

* ignore device shapeinfo coming from java

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor transformAny fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweak for lazy host allocation

Signed-off-by: raver119 <raver119@gmail.com>

* - DataBuffer::memcpy method
- bitcast now uses memcpy

Signed-off-by: raver119 <raver119@gmail.com>

* - IndexReduce CUDA dimension buffer fix

Signed-off-by: raver119 <raver119@gmail.com>

* views for CPU and CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less spam

Signed-off-by: raver119 <raver119@gmail.com>

* optional memory init

Signed-off-by: raver119 <raver119@gmail.com>

* async memset

Signed-off-by: raver119 <raver119@gmail.com>

* - SummaryStats CUDA fix
- DataBuffer.sameUnderlyingData() impl
- execBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* - reduce3All fix
switch to CUDA 10 temporarily

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA version

Signed-off-by: raver119 <raver119@gmail.com>

* proper memory deallocator registration

Signed-off-by: raver119 <raver119@gmail.com>

* HOST_ONLY workspace allocation

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* few conflicts resolved

Signed-off-by: raver119 <raver119@gmail.com>

* few minor fixes

Signed-off-by: raver119 <raver119@gmail.com>

* one more minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray permute should operate on JVM primitives

Signed-off-by: raver119 <raver119@gmail.com>

* - create InteropDataBuffer for shapes as well
- update pointers after view creation in Java

Signed-off-by: raver119 <raver119@gmail.com>

* - addressPointer temporary moved to C++

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA: don't account offset twice

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA: DataBuffer pointer constructor updated

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA NDArray.unsafeDuplication() simplified

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA minor workspace-related fixes

Signed-off-by: raver119 <raver119@gmail.com>

* CPU DataBuffer.reallocate()

Signed-off-by: raver119 <raver119@gmail.com>

* print_affinity op

Signed-off-by: raver119 <raver119@gmail.com>

* print_affinity java side

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA more tweaks for data locality

Signed-off-by: raver119 <raver119@gmail.com>

* - compat_string_split tweak
- CudaUtf8Buffer update

Signed-off-by: raver119 <raver119@gmail.com>

* INDArray.close() mechanic restored

Signed-off-by: raver119 <raver119@gmail.com>

* one more test fixed

Signed-off-by: raver119 <raver119@gmail.com>

* - CUDA DataBuffer.reallocate() updated
- cudaMemcpy (synchronous) restored

Signed-off-by: raver119 <raver119@gmail.com>

* one last fix

Signed-off-by: raver119 <raver119@gmail.com>

* bad import removed

Signed-off-by: raver119 <raver119@gmail.com>

* another small fix

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* fix bad databuffer size

Signed-off-by: raver119 <raver119@gmail.com>

* release primaryBuffer on replace

Signed-off-by: raver119 <raver119@gmail.com>

* higher timeout

Signed-off-by: raver119 <raver119@gmail.com>

* disable timeouts

Signed-off-by: raver119 <raver119@gmail.com>

* dbCreateView now validates offset and length of a view

Signed-off-by: raver119 <raver119@gmail.com>

* additional validation for dbExpand

Signed-off-by: raver119 <raver119@gmail.com>

* restore timeout back again

Signed-off-by: raver119 <raver119@gmail.com>

* smaller distribution for rng test to prevent timeouts

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA DataBuffer::memcpy now copies to device all the time

Signed-off-by: raver119 <raver119@gmail.com>

* OpaqueDataBuffer now contains all required methods for interop

Signed-off-by: raver119 <raver119@gmail.com>

* some javadoc

Signed-off-by: raver119 <raver119@gmail.com>

* GC on failed allocations

Signed-off-by: raver119 <raver119@gmail.com>

* minoe memcpu tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more bitcast test

Signed-off-by: raver119 <raver119@gmail.com>

* - NDArray::deviceId() propagation
- special multi-threaded test for data locality checks

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer additional syncStream

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer additional syncStream

Signed-off-by: raver119 <raver119@gmail.com>

* one ignored test

Signed-off-by: raver119 <raver119@gmail.com>

* skip host alloc for empty arrays

Signed-off-by: raver119 <raver119@gmail.com>

* ByteBuffer support is back

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer::memcpy minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* few minor prelu/bp tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* nullify-related fixes

Signed-off-by: raver119 <raver119@gmail.com>

* PReLU fixes (#157)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Build fixed

* Fix tests

* one more ByteBuffer signature restored

Signed-off-by: raver119 <raver119@gmail.com>

* nd4j-jdbc-hsql profiles fix

Signed-off-by: raver119 <raver119@gmail.com>

* nd4j-jdbc-hsql profiles fix

Signed-off-by: raver119 <raver119@gmail.com>

* PReLU weight init fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small PReLU fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* - INDArray.migrate() reactivated
- DataBuffer::setDeviceId(...) added
- InteropDataBuffer Z syncToDevice added for views

Signed-off-by: raver119 <raver119@gmail.com>

* missed file

Signed-off-by: raver119 <raver119@gmail.com>

* Small tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* cuda 10.2

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: shugeo <sgazeos@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
raver119 2020-01-04 13:27:50 +03:00 committed by GitHub
parent 451d9d57fd
commit 29e8e09db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
169 changed files with 8463 additions and 7839 deletions

View File

@ -121,6 +121,7 @@ public class PReLULayer extends BaseLayer {
public static class Builder extends FeedForwardLayer.Builder<PReLULayer.Builder> { public static class Builder extends FeedForwardLayer.Builder<PReLULayer.Builder> {
public Builder(){ public Builder(){
//Default to 0s, and don't inherit global default
this.weightInitFn = new WeightInitConstant(0); this.weightInitFn = new WeightInitConstant(0);
} }

View File

@ -20,7 +20,7 @@ import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.FloatBuffer; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -63,7 +63,7 @@ public class NegativeHolder implements Serializable {
protected void makeTable(int tableSize, double power) { protected void makeTable(int tableSize, double power) {
int vocabSize = vocab.numWords(); int vocabSize = vocab.numWords();
table = Nd4j.create(new FloatBuffer(tableSize)); table = Nd4j.create(DataType.FLOAT, tableSize);
double trainWordsPow = 0.0; double trainWordsPow = 0.0;
for (String word : vocab.words()) { for (String word : vocab.words()) {
trainWordsPow += Math.pow(vocab.wordFrequency(word), power); trainWordsPow += Math.pow(vocab.wordFrequency(word), power);

View File

@ -42,6 +42,8 @@
#include <helpers/ConstantShapeHelper.h> #include <helpers/ConstantShapeHelper.h>
#include <array/DataBuffer.h> #include <array/DataBuffer.h>
#include <execution/AffinityManager.h> #include <execution/AffinityManager.h>
#include <memory>
#include <array/InteropDataBuffer.h>
namespace nd4j { namespace nd4j {
@ -301,14 +303,11 @@ namespace nd4j {
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList); static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type

View File

@ -223,6 +223,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& desc
setShapeInfo(descriptor); setShapeInfo(descriptor);
_buffer = buffer; _buffer = buffer;
_isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -288,6 +290,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std
setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape));
_buffer = buffer; _buffer = buffer;
_isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////

View File

@ -68,6 +68,7 @@ bool verbose = false;
#include <array/ConstantDescriptor.h> #include <array/ConstantDescriptor.h>
#include <helpers/ConstantShapeHelper.h> #include <helpers/ConstantShapeHelper.h>
#include <array/ConstantDataBuffer.h> #include <array/ConstantDataBuffer.h>
#include <array/InteropDataBuffer.h>
#include <helpers/ConstantHelper.h> #include <helpers/ConstantHelper.h>
#include <array/TadPack.h> #include <array/TadPack.h>
#include <graph/VariablesSet.h> #include <graph/VariablesSet.h>
@ -76,6 +77,8 @@ bool verbose = false;
#include <graph/ResultWrapper.h> #include <graph/ResultWrapper.h>
#include <DebugInfo.h> #include <DebugInfo.h>
typedef nd4j::InteropDataBuffer OpaqueDataBuffer;
extern "C" { extern "C" {
/** /**
@ -118,11 +121,9 @@ ND4J_EXPORT void setTADThreshold(int num);
*/ */
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
/** /**
* *
@ -137,13 +138,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
/** /**
* *
@ -160,28 +158,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
ND4J_EXPORT void execBroadcast( ND4J_EXPORT void execBroadcast(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execBroadcastBool( ND4J_EXPORT void execBroadcastBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *dDimension, Nd4jLong *dDimensionShape);
/** /**
* *
@ -198,23 +188,17 @@ ND4J_EXPORT void execBroadcastBool(
ND4J_EXPORT void execPairwiseTransform( ND4J_EXPORT void execPairwiseTransform(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execPairwiseTransformBool( ND4J_EXPORT void execPairwiseTransformBool(
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -228,36 +212,28 @@ ND4J_EXPORT void execPairwiseTransformBool(
*/ */
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
/** /**
* *
@ -270,46 +246,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
/** /**
* *
@ -324,13 +288,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
/** /**
* *
@ -343,13 +304,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
/** /**
* *
* @param opNum * @param opNum
@ -365,30 +323,22 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParamsVals, void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets); Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
@ -405,22 +355,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -432,11 +376,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -449,11 +391,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected); bool biasCorrected);
/** /**
* *
@ -468,13 +408,10 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
bool biasCorrected, bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
@ -490,42 +427,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams); void *extraParams);
/** /**
@ -543,29 +470,21 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *extraParams, void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
void *extraParams, void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
@ -904,10 +823,8 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
* @param zTadOffsets * @param zTadOffsets
*/ */
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
Nd4jLong n, Nd4jLong n,
Nd4jLong *indexes, Nd4jLong *indexes,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
@ -1086,8 +1003,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
/** /**
@ -1106,12 +1022,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer, OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
/** /**
@ -1128,10 +1041,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraArguments); void *extraArguments);
@ -1174,52 +1085,6 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers,
*/ */
ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom); ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
/**
* Grid operations
*/
/**
*
* @param extras
* @param opTypeA
* @param opNumA
* @param opTypeB
* @param opNumB
* @param N
* @param dx
* @param xShapeInfo
* @param dy
* @param yShapeInfo
* @param dz
* @param zShapeInfo
* @param extraA
* @param extraB
* @param scalarA
* @param scalarB
*/
/*
ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras,
const int opTypeA,
const int opNumA,
const int opTypeB,
const int opNumB,
Nd4jLong N,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraA,
void *extraB,
double scalarA,
double scalarB);
*/
} }
/** /**
@ -1561,8 +1426,7 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address);
* @return * @return
*/ */
ND4J_EXPORT void tear(Nd4jPointer *extraPointers, ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
Nd4jPointer *targets, Nd4jLong *zShapeInfo, Nd4jPointer *targets, Nd4jLong *zShapeInfo,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets); Nd4jLong *tadOffsets);
@ -1739,6 +1603,8 @@ ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace)
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
@ -1766,6 +1632,28 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth);
ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset);
ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes);
ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes);
ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId);
ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
ND4J_EXPORT int binaryLevel(); ND4J_EXPORT int binaryLevel();
ND4J_EXPORT int optimalLevel(); ND4J_EXPORT int optimalLevel();

View File

@ -184,16 +184,16 @@ void NDArray::synchronize(const char* msg) const {
// no-op // no-op
} }
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
// no-op // no-op
} }
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
// no-op // no-op
} }
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
// no-op // no-op
} }
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
// no-op // no-op
} }

File diff suppressed because it is too large Load Diff

View File

@ -236,7 +236,7 @@ void NDArray::synchronize(const char* msg) const {
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList) for (const auto& a : readList)
if(a != nullptr) if(a != nullptr)
@ -252,7 +252,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
for (const auto& p : readList) for (const auto& p : readList)
if(p != nullptr) if(p != nullptr)
@ -264,7 +264,7 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList) for (const auto& a : readList)
if(a != nullptr) if(a != nullptr)
@ -280,7 +280,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
for (const auto& p : readList) for (const auto& p : readList)
if(p != nullptr) if(p != nullptr)

File diff suppressed because it is too large Load Diff

View File

@ -34,10 +34,12 @@
#define ARRAY_SPARSE 2 #define ARRAY_SPARSE 2
#define ARRAY_COMPRESSED 4 #define ARRAY_COMPRESSED 4
#define ARRAY_EMPTY 8 #define ARRAY_EMPTY 8
#define ARRAY_RAGGED 16
#define ARRAY_CSR 16
#define ARRAY_CSC 32 #define ARRAY_CSR 32
#define ARRAY_COO 64 #define ARRAY_CSC 64
#define ARRAY_COO 128
// complex values // complex values
#define ARRAY_COMPLEX 512 #define ARRAY_COMPLEX 512
@ -72,8 +74,10 @@
// boolean values // boolean values
#define ARRAY_BOOL 524288 #define ARRAY_BOOL 524288
// utf-8 values // UTF values
#define ARRAY_STRING 1048576 #define ARRAY_UTF8 1048576
#define ARRAY_UTF16 4194304
#define ARRAY_UTF32 16777216
// flag for extras // flag for extras
#define ARRAY_EXTRAS 2097152 #define ARRAY_EXTRAS 2097152
@ -173,8 +177,12 @@ namespace nd4j {
return nd4j::DataType ::UINT32; return nd4j::DataType ::UINT32;
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
return nd4j::DataType ::UINT64; return nd4j::DataType ::UINT64;
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
return nd4j::DataType ::UTF8; return nd4j::DataType ::UTF8;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
return nd4j::DataType ::UTF16;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
return nd4j::DataType ::UTF32;
else { else {
//shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo)); //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
@ -190,8 +198,12 @@ namespace nd4j {
return nd4j::DataType::INT32; return nd4j::DataType::INT32;
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
return nd4j::DataType::INT64; return nd4j::DataType::INT64;
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
return nd4j::DataType::UTF8; return nd4j::DataType::UTF8;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
return nd4j::DataType::UTF16;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
return nd4j::DataType::UTF32;
else { else {
//shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo)); //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
@ -224,6 +236,8 @@ namespace nd4j {
return ArrayType::COMPRESSED; return ArrayType::COMPRESSED;
else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
return ArrayType::EMPTY; return ArrayType::EMPTY;
else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED))
return ArrayType::RAGGED;
else // by default we return DENSE type here else // by default we return DENSE type here
return ArrayType::DENSE; return ArrayType::DENSE;
} }
@ -333,7 +347,13 @@ namespace nd4j {
setPropertyBit(shapeInfo, ARRAY_LONG); setPropertyBit(shapeInfo, ARRAY_LONG);
break; break;
case nd4j::DataType::UTF8: case nd4j::DataType::UTF8:
setPropertyBit(shapeInfo, ARRAY_STRING); setPropertyBit(shapeInfo, ARRAY_UTF8);
break;
case nd4j::DataType::UTF16:
setPropertyBit(shapeInfo, ARRAY_UTF16);
break;
case nd4j::DataType::UTF32:
setPropertyBit(shapeInfo, ARRAY_UTF32);
break; break;
default: default:
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__

View File

@ -27,6 +27,7 @@ namespace nd4j {
SPARSE = 2, SPARSE = 2,
COMPRESSED = 3, COMPRESSED = 3,
EMPTY = 4, EMPTY = 4,
RAGGED = 5,
}; };
} }

View File

@ -36,13 +36,14 @@ class ND4J_EXPORT DataBuffer {
private: private:
void* _primaryBuffer; void* _primaryBuffer = nullptr;
void* _specialBuffer; void* _specialBuffer = nullptr;
size_t _lenInBytes; size_t _lenInBytes = 0;
DataType _dataType; DataType _dataType;
memory::Workspace* _workspace; memory::Workspace* _workspace = nullptr;
bool _isOwnerPrimary; bool _isOwnerPrimary;
bool _isOwnerSpecial; bool _isOwnerSpecial;
std::atomic<int> _deviceId;
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
mutable std::atomic<Nd4jLong> _counter; mutable std::atomic<Nd4jLong> _counter;
@ -55,9 +56,9 @@ class ND4J_EXPORT DataBuffer {
void setCountersToZero(); void setCountersToZero();
void copyCounters(const DataBuffer& other); void copyCounters(const DataBuffer& other);
void deleteSpecial(); void deleteSpecial();
FORCEINLINE void deletePrimary(); void deletePrimary();
FORCEINLINE void deleteBuffers(); void deleteBuffers();
FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
void allocateBuffers(const bool allocBoth = false); void allocateBuffers(const bool allocBoth = false);
void setSpecial(void* special, const bool isOwnerSpecial); void setSpecial(void* special, const bool isOwnerSpecial);
void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0);
@ -65,37 +66,38 @@ class ND4J_EXPORT DataBuffer {
public: public:
FORCEINLINE DataBuffer(void* primary, void* special, DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType, const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false, const bool isOwnerSpecial = false, const bool isOwnerPrimary = false, const bool isOwnerSpecial = false,
memory::Workspace* workspace = nullptr); memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(void* primary, DataBuffer(void* primary,
const size_t lenInBytes, const DataType dataType, const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false, const bool isOwnerPrimary = false,
memory::Workspace* workspace = nullptr); memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
const DataType dataType, const size_t lenInBytes, const DataType dataType, const size_t lenInBytes,
memory::Workspace* workspace = nullptr); memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
FORCEINLINE DataBuffer(const DataBuffer& other); DataBuffer(const DataBuffer& other);
FORCEINLINE DataBuffer(DataBuffer&& other); DataBuffer(DataBuffer&& other);
FORCEINLINE explicit DataBuffer(); explicit DataBuffer();
FORCEINLINE ~DataBuffer(); ~DataBuffer();
FORCEINLINE DataBuffer& operator=(const DataBuffer& other); DataBuffer& operator=(const DataBuffer& other);
FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept; DataBuffer& operator=(DataBuffer&& other) noexcept;
FORCEINLINE DataType getDataType(); DataType getDataType();
FORCEINLINE size_t getLenInBytes() const; void setDataType(DataType dataType);
size_t getLenInBytes() const;
FORCEINLINE void* primary(); void* primary();
FORCEINLINE void* special(); void* special();
FORCEINLINE void allocatePrimary(); void allocatePrimary();
void allocateSpecial(); void allocateSpecial();
void writePrimary() const; void writePrimary() const;
@ -105,6 +107,10 @@ class ND4J_EXPORT DataBuffer {
bool isPrimaryActual() const; bool isPrimaryActual() const;
bool isSpecialActual() const; bool isSpecialActual() const;
void expand(const uint64_t size);
int deviceId() const;
void setDeviceId(int deviceId);
void migrate(); void migrate();
template <typename T> FORCEINLINE T* primaryAsT(); template <typename T> FORCEINLINE T* primaryAsT();
@ -118,256 +124,28 @@ class ND4J_EXPORT DataBuffer {
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
static void memcpy(const DataBuffer &dst, const DataBuffer &src); static void memcpy(const DataBuffer &dst, const DataBuffer &src);
void setPrimaryBuffer(void *buffer, size_t length);
void setSpecialBuffer(void *buffer, size_t length);
/**
* This method deletes buffers, if we're owners
*/
void close();
}; };
///// IMLEMENTATION OF INLINE METHODS ///// ///// IMLEMENTATION OF INLINE METHODS /////
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
// default constructor template <typename T>
DataBuffer::DataBuffer() { T* DataBuffer::primaryAsT() {
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = 0;
_dataType = INT8;
_workspace = nullptr;
_isOwnerPrimary = false;
_isOwnerSpecial = false;
setCountersToZero();
}
////////////////////////////////////////////////////////////////////////
// copy constructor
DataBuffer::DataBuffer(const DataBuffer &other) {
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
allocateBuffers();
copyBufferFrom(other);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary, const bool isOwnerSpecial,
memory::Workspace* workspace) {
if (primary == nullptr && special == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
_primaryBuffer = primary;
_specialBuffer = special;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
setCountersToZero();
if(primary != nullptr)
readPrimary();
if(special != nullptr)
readSpecial();
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
syncToSpecial(true);
}
////////////////////////////////////////////////////////////////////////
// copies data from hostBuffer to own memory buffer
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
if (hostBuffer == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
if (lenInBytes == 0)
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
setCountersToZero();
allocateBuffers();
copyBufferFromHost(hostBuffer, lenInBytes);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
_dataType = dataType;
_workspace = workspace;
_lenInBytes = lenInBytes;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
if(lenInBytes != 0) {
allocateBuffers(allocBoth);
writeSpecial();
}
}
////////////////////////////////////////////////////////////////////////
// move constructor
DataBuffer::DataBuffer(DataBuffer&& other) {
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
if (this == &other)
return *this;
deleteBuffers();
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
if (this == &other)
return *this;
deleteBuffers();
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::primary() {
return _primaryBuffer;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::special() {
return _specialBuffer;
}
////////////////////////////////////////////////////////////////////////
DataType DataBuffer::getDataType() {
return _dataType;
}
////////////////////////////////////////////////////////////////////////
size_t DataBuffer::getLenInBytes() const {
return _lenInBytes;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T* DataBuffer::primaryAsT() {
return reinterpret_cast<T*>(_primaryBuffer); return reinterpret_cast<T*>(_primaryBuffer);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
T* DataBuffer::specialAsT() { T* DataBuffer::specialAsT() {
return reinterpret_cast<T*>(_specialBuffer); return reinterpret_cast<T*>(_specialBuffer);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocatePrimary() {
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
_isOwnerPrimary = true;
} }
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deletePrimary() {
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deleteBuffers() {
deletePrimary();
deleteSpecial();
_lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
DataBuffer::~DataBuffer() {
deleteBuffers();
}
} }

View File

@ -42,6 +42,8 @@ namespace nd4j {
QINT16 = 16, QINT16 = 16,
BFLOAT16 = 17, BFLOAT16 = 17,
UTF8 = 50, UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
ANY = 100, ANY = 100,
AUTO = 200, AUTO = 200,
}; };

View File

@ -0,0 +1,71 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <dll.h>
#include <array/DataBuffer.h>
#include <array/DataType.h>
#include <memory>
#ifndef LIBND4J_INTEROPDATABUFFER_H
#define LIBND4J_INTEROPDATABUFFER_H
namespace nd4j {
/**
* This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages
*/
class ND4J_EXPORT InteropDataBuffer {
private:
std::shared_ptr<DataBuffer> _dataBuffer;
uint64_t _offset = 0;
public:
InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset);
InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer);
InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth);
~InteropDataBuffer() = default;
#ifndef __JAVACPP_HACK__
std::shared_ptr<DataBuffer> getDataBuffer() const;
std::shared_ptr<DataBuffer> dataBuffer();
#endif
void* primary() const;
void* special() const;
uint64_t offset() const ;
void setOffset(uint64_t offset);
void setPrimary(void* ptr, size_t length);
void setSpecial(void* ptr, size_t length);
void expand(size_t newlength);
int deviceId() const;
void setDeviceId(int deviceId);
static void registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
static void prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
static void registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
static void preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
};
}
#endif //LIBND4J_INTEROPDATABUFFER_H

View File

@ -23,6 +23,24 @@
#include <DataTypeUtils.h> #include <DataTypeUtils.h>
namespace nd4j { namespace nd4j {
void DataBuffer::expand(const uint64_t size) {
if (size > _lenInBytes) {
// allocate new buffer
int8_t *newBuffer = nullptr;
ALLOCATE(newBuffer, _workspace, size, int8_t);
// copy data from existing buffer
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
if (_isOwnerPrimary) {
RELEASE(reinterpret_cast<int8_t *>(_primaryBuffer), _workspace);
}
_primaryBuffer = newBuffer;
_lenInBytes = size;
_isOwnerPrimary = true;
}
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::setCountersToZero() { void DataBuffer::setCountersToZero() {
@ -99,14 +117,17 @@ void DataBuffer::allocateSpecial() {
void DataBuffer::migrate() { void DataBuffer::migrate() {
} }
///////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes); /////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes > dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes);
dst.readPrimary();
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { } void DataBuffer::writePrimary() const { }
void DataBuffer::writeSpecial() const { } void DataBuffer::writeSpecial() const { }

View File

@ -25,6 +25,40 @@
#include <exceptions/cuda_exception.h> #include <exceptions/cuda_exception.h>
namespace nd4j { namespace nd4j {
void DataBuffer::expand(const uint64_t size) {
if (size > _lenInBytes) {
// allocate new buffer
int8_t *newBuffer = nullptr;
int8_t *newSpecialBuffer = nullptr;
ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t);
// copy data from existing buffer
if (_primaryBuffer != nullptr) {
// there's non-zero chance that primary buffer doesn't exist yet
ALLOCATE(newBuffer, _workspace, size, int8_t);
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
if (_isOwnerPrimary) {
auto ipb = reinterpret_cast<int8_t *>(_primaryBuffer);
RELEASE(ipb, _workspace);
}
_primaryBuffer = newBuffer;
_isOwnerPrimary = true;
}
cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice);
if (_isOwnerSpecial) {
auto isb = reinterpret_cast<int8_t *>(_specialBuffer);
RELEASE_SPECIAL(isb, _workspace);
}
_specialBuffer = newSpecialBuffer;
_lenInBytes = size;
_isOwnerSpecial = true;
}
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::allocateSpecial() { void DataBuffer::allocateSpecial() {
@ -37,8 +71,9 @@ void DataBuffer::allocateSpecial() {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) {
if(isPrimaryActual() && !forceSync) if(isPrimaryActual() && !forceSync) {
return; return;
}
allocatePrimary(); allocatePrimary();
@ -46,7 +81,9 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
if (res != 0) if (res != 0)
throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res); throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res);
cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost);
if (res != 0)
throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res);
readPrimary(); readPrimary();
} }
@ -54,13 +91,19 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::syncToSpecial(const bool forceSync) { void DataBuffer::syncToSpecial(const bool forceSync) {
// in this case there's nothing to do here
if(isSpecialActual() && !forceSync) if (_primaryBuffer == nullptr)
return; return;
if(isSpecialActual() && !forceSync) {
return;
}
allocateSpecial(); allocateSpecial();
cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice);
if (res != 0)
throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res);
readSpecial(); readSpecial();
} }
@ -97,19 +140,6 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
_readPrimary.store(other._writeSpecial); _readPrimary.store(other._writeSpecial);
_readSpecial.store(other._writePrimary); _readSpecial.store(other._writePrimary);
} }
////////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
if (src.isSpecialActual()) {
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
} else if (src.isPrimaryActual()) {
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
}
dst.writeSpecial();
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
@ -176,8 +206,11 @@ void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate s
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::setToZeroBuffers(const bool both) { void DataBuffer::setToZeroBuffers(const bool both) {
cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream());
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
if (res != 0)
throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res);
cudaMemset(special(), 0, getLenInBytes());
writeSpecial(); writeSpecial();
if(both) { if(both) {
@ -186,12 +219,37 @@ void DataBuffer::setToZeroBuffers(const bool both) {
} }
} }
/////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes > dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
int res = 0;
if (src.isSpecialActual()) {
res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream());
} else if (src.isPrimaryActual()) {
res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream());
}
if (res != 0)
throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res);
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
if (res != 0)
throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res);
dst.writeSpecial();
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() { void DataBuffer::migrate() {
memory::Workspace* newWorkspace = nullptr; memory::Workspace* newWorkspace = nullptr;
void* newBuffer; void* newBuffer;
ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
if (res != 0)
throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res);
if (_isOwnerSpecial) { if (_isOwnerSpecial) {
// now we're releasing original buffer // now we're releasing original buffer
@ -203,7 +261,7 @@ void DataBuffer::migrate() {
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { _writePrimary = ++_counter; } void DataBuffer::writePrimary() const {_writePrimary = ++_counter; }
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }
void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } void DataBuffer::readPrimary() const { _readPrimary = ++_counter; }
void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } void DataBuffer::readSpecial() const { _readSpecial = ++_counter; }

View File

@ -0,0 +1,301 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <array/DataBuffer.h>
#include <helpers/logger.h>
#include <array/DataTypeUtils.h>
#include <execution/AffinityManager.h>
namespace nd4j {
///// IMLEMENTATION OF COMMON METHODS /////
////////////////////////////////////////////////////////////////////////
// default constructor
DataBuffer::DataBuffer() {
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = 0;
_dataType = INT8;
_workspace = nullptr;
_isOwnerPrimary = false;
_isOwnerSpecial = false;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
}
////////////////////////////////////////////////////////////////////////
// copy constructor
DataBuffer::DataBuffer(const DataBuffer &other) {
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_deviceId.store(other._deviceId.load());
setCountersToZero();
allocateBuffers();
copyBufferFrom(other);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary, const bool isOwnerSpecial,
memory::Workspace* workspace) {
if (primary == nullptr && special == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
_primaryBuffer = primary;
_specialBuffer = special;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
if(primary != nullptr)
readPrimary();
if(special != nullptr)
readSpecial();
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
syncToSpecial(true);
}
////////////////////////////////////////////////////////////////////////
// copies data from hostBuffer to own memory buffer
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
if (hostBuffer == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
if (lenInBytes == 0)
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
allocateBuffers();
copyBufferFromHost(hostBuffer, lenInBytes);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
_dataType = dataType;
_workspace = workspace;
_lenInBytes = lenInBytes;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
if(lenInBytes != 0) {
allocateBuffers(allocBoth);
writeSpecial();
}
}
////////////////////////////////////////////////////////////////////////
// move constructor
DataBuffer::DataBuffer(DataBuffer&& other) {
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
_deviceId.store(other._deviceId);
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
if (this == &other)
return *this;
deleteBuffers();
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
if (this == &other)
return *this;
deleteBuffers();
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::primary() {
return _primaryBuffer;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::special() {
return _specialBuffer;
}
////////////////////////////////////////////////////////////////////////
DataType DataBuffer::getDataType() {
return _dataType;
}
////////////////////////////////////////////////////////////////////////
size_t DataBuffer::getLenInBytes() const {
return _lenInBytes;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocatePrimary() {
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
_isOwnerPrimary = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deletePrimary() {
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deleteBuffers() {
deletePrimary();
deleteSpecial();
_lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
DataBuffer::~DataBuffer() {
deleteBuffers();
}
void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) {
if (_primaryBuffer != nullptr && _isOwnerPrimary) {
deletePrimary();
}
_primaryBuffer = buffer;
_isOwnerPrimary = false;
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
this->setSpecial(buffer, false);
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
void DataBuffer::setDataType(DataType dataType) {
_dataType = dataType;
}
int DataBuffer::deviceId() const {
return _deviceId.load();
}
void DataBuffer::close() {
this->deleteBuffers();
}
void DataBuffer::setDeviceId(int deviceId) {
_deviceId = deviceId;
}
}

View File

@ -0,0 +1,146 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <array/InteropDataBuffer.h>
#include <array/DataTypeUtils.h>
#include <execution/AffinityManager.h>
#include <helpers/logger.h>
namespace nd4j {
InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) {
_dataBuffer = dataBuffer.getDataBuffer();
// offset is always absolute to the original buffer
_offset = offset;
if (_offset + length > _dataBuffer->getLenInBytes()) {
throw std::runtime_error("offset + length is higher than original length");
}
}
InteropDataBuffer::InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer) {
_dataBuffer = databuffer;
}
InteropDataBuffer::InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth) {
if (elements == 0) {
_dataBuffer = std::make_shared<DataBuffer>();
_dataBuffer->setDataType(dtype);
} else {
_dataBuffer = std::make_shared<DataBuffer>(elements, dtype, nullptr, allocateBoth);
}
}
std::shared_ptr<DataBuffer> InteropDataBuffer::getDataBuffer() const {
return _dataBuffer;
}
std::shared_ptr<DataBuffer> InteropDataBuffer::dataBuffer() {
return _dataBuffer;
}
void* InteropDataBuffer::primary() const {
return reinterpret_cast<int8_t *>(_dataBuffer->primary()) + _offset;
}
void* InteropDataBuffer::special() const {
return reinterpret_cast<int8_t *>(_dataBuffer->special()) + _offset;
}
void InteropDataBuffer::setPrimary(void* ptr, size_t length) {
_dataBuffer->setPrimaryBuffer(ptr, length);
}
void InteropDataBuffer::setSpecial(void* ptr, size_t length) {
_dataBuffer->setSpecialBuffer(ptr, length);
}
uint64_t InteropDataBuffer::offset() const {
return _offset;
}
void InteropDataBuffer::setOffset(uint64_t offset) {
_offset = offset;
}
int InteropDataBuffer::deviceId() const {
return _dataBuffer->deviceId();
}
void InteropDataBuffer::registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
for (const auto &v:writeList) {
if (v == nullptr)
continue;
v->getDataBuffer()->writeSpecial();
}
}
void InteropDataBuffer::prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
auto currentDeviceId = nd4j::AffinityManager::currentDeviceId();
for (const auto &v:readList) {
if (v == nullptr)
continue;
if (v->getDataBuffer()->deviceId() != currentDeviceId)
v->getDataBuffer()->migrate();
v->getDataBuffer()->syncToSpecial();
}
// we don't tick write list, only ensure the same device affinity
for (const auto &v:writeList) {
if (v == nullptr)
continue;
// special case for legacy ops - views can be updated on host side, thus original array can be not updated
if (!v->getDataBuffer()->isSpecialActual())
v->getDataBuffer()->syncToSpecial();
if (v->getDataBuffer()->deviceId() != currentDeviceId)
v->getDataBuffer()->migrate();
}
}
void InteropDataBuffer::registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
for (const auto &v:writeList) {
if (v == nullptr)
continue;
}
}
void InteropDataBuffer::preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
for (const auto &v:readList) {
if (v == nullptr)
continue;
v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext());
}
}
void InteropDataBuffer::expand(size_t newlength) {
_dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType()));
}
void InteropDataBuffer::setDeviceId(int deviceId) {
_dataBuffer->setDeviceId(deviceId);
}
}

View File

@ -138,7 +138,7 @@ namespace nd4j {
if (res != 0) if (res != 0)
throw cuda_exception::build("_reductionPointer allocation failed", res); throw cuda_exception::build("_reductionPointer allocation failed", res);
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16); res = cudaHostAlloc(reinterpret_cast<void**>(&_scalarPointer), 16, cudaHostAllocDefault);
if (res != 0) if (res != 0)
throw cuda_exception::build("_scalarPointer allocation failed", res); throw cuda_exception::build("_scalarPointer allocation failed", res);

View File

@ -185,9 +185,11 @@ namespace nd4j {
void setInputArray(int index, NDArray *array, bool removable = false); void setInputArray(int index, NDArray *array, bool removable = false);
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setOutputArray(int index, NDArray *array, bool removable = false); void setOutputArray(int index, NDArray *array, bool removable = false);
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setTArguments(double *arguments, int numberOfArguments); void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments);

View File

@ -21,6 +21,7 @@
#include <Context.h> #include <Context.h>
#include <helpers/ShapeUtils.h> #include <helpers/ShapeUtils.h>
#include <graph/Context.h> #include <graph/Context.h>
#include <array/InteropDataBuffer.h>
namespace nd4j { namespace nd4j {
@ -426,6 +427,44 @@ namespace nd4j {
array->setContext(_context); array->setContext(_context);
} }
void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_in.size() < index + 1)
_fastpath_in.resize(index+1);
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
_fastpath_in[index] = array;
_handles.emplace_back(array);
if (_context != nullptr)
array->setContext(_context);
}
void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_out.size() < index + 1)
_fastpath_out.resize(index+1);
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
_fastpath_out[index] = array;
_handles.emplace_back(array);
if (_context != nullptr)
array->setContext(_context);
}
void Context::setTArguments(double *arguments, int numberOfArguments) { void Context::setTArguments(double *arguments, int numberOfArguments) {
_tArgs.clear(); _tArgs.clear();
_tArgs.reserve(numberOfArguments); _tArgs.reserve(numberOfArguments);

View File

@ -43,6 +43,8 @@ enum DType:byte {
QINT16, QINT16,
BFLOAT16 = 17, BFLOAT16 = 17,
UTF8 = 50, UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
} }
// this structure describe NDArray // this structure describe NDArray

View File

@ -34,8 +34,6 @@
#include <driver_types.h> #include <driver_types.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
#endif #endif
#include <DebugInfo.h> #include <DebugInfo.h>
namespace nd4j { namespace nd4j {

View File

@ -25,6 +25,8 @@
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <vector>
#include <NDArray.h>
namespace nd4j { namespace nd4j {
class ND4J_EXPORT StringUtils { class ND4J_EXPORT StringUtils {
@ -53,6 +55,36 @@ namespace nd4j {
return result; return result;
} }
/**
* This method returns number of needle matches within haystack
* PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8
*
* @param haystack
* @param haystackLength
* @param needle
* @param needleLength
* @return
*/
static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength);
/**
* This method returns number of bytes used for string NDArrays content
* PLEASE NOTE: this doesn't include header
*
* @param array
* @return
*/
static uint64_t byteLength(const NDArray &array);
/**
* This method splits a string into substring by delimiter
*
* @param haystack
* @param delimiter
* @return
*/
static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
}; };
} }

View File

@ -19,7 +19,58 @@
// //
#include <helpers/StringUtils.h> #include <helpers/StringUtils.h>
#include <exceptions/datatype_exception.h>
namespace nd4j { namespace nd4j {
static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) {
for (int e = 0; e < length; e++)
if (haystack[e] != needle[e])
return false;
return true;
}
uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) {
auto haystack = reinterpret_cast<const uint8_t*>(vhaystack);
auto needle = reinterpret_cast<const uint8_t*>(vneedle);
uint64_t number = 0;
for (uint64_t e = 0; e < haystackLength - needleLength; e++) {
if (match(&haystack[e], needle, needleLength))
number++;
}
return number;
}
uint64_t StringUtils::byteLength(const NDArray &array) {
if (!array.isS())
throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());
uint64_t result = 0;
// our buffer stores offsets, and the last value is basically number of bytes used
auto buffer = array.bufferAsT<Nd4jLong>();
result = buffer[array.lengthOf()];
return result;
}
std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
std::vector<std::string> output;
std::string::size_type prev_pos = 0, pos = 0;
// iterating through the haystack till the end
while((pos = haystack.find(delimiter, pos)) != std::string::npos) {
output.emplace_back(haystack.substr(prev_pos, pos-prev_pos));
prev_pos = ++pos;
}
output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word
return output;
}
} }

View File

@ -20,7 +20,6 @@
// //
#include <types/types.h> #include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <loops/reduce_bool.h> #include <loops/reduce_bool.h>
#include <loops/legacy_ops.h> #include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
// //
#include <types/types.h> #include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <loops/reduce_float.h> #include <loops/reduce_float.h>
#include <loops/legacy_ops.h> #include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
// //
#include <types/types.h> #include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <loops/reduce_long.h> #include <loops/reduce_long.h>
#include <loops/legacy_ops.h> #include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
// //
#include <types/types.h> #include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <loops/reduce_same.h> #include <loops/reduce_same.h>
#include <loops/legacy_ops.h> #include <loops/legacy_ops.h>

View File

@ -1624,4 +1624,9 @@
#define PARAMETRIC_D() [&] (Parameters &p) -> Context* #define PARAMETRIC_D() [&] (Parameters &p) -> Context*
#ifdef __CUDABLAS__
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
#endif
#endif #endif

View File

@ -40,6 +40,9 @@
#include <ops/declarable/headers/third_party.h> #include <ops/declarable/headers/third_party.h>
#include <ops/declarable/headers/tests.h> #include <ops/declarable/headers/tests.h>
#include <ops/declarable/headers/kernels.h> #include <ops/declarable/headers/kernels.h>
#include <ops/declarable/headers/strings.h>
#include <ops/declarable/headers/compat.h>
#include <ops/declarable/headers/util.h>
#include <ops/declarable/headers/BarnesHutTsne.h> #include <ops/declarable/headers/BarnesHutTsne.h>
#include <ops/declarable/headers/images.h> #include <ops/declarable/headers/images.h>
#include <dll.h> #include <dll.h>

View File

@ -0,0 +1 @@
This folder contains operations required for compatibility with TF and other frameworks.

View File

@ -0,0 +1,73 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/sparse_to_dense.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) {
auto indices = INPUT_VARIABLE(0);
auto shape = INPUT_VARIABLE(1);
auto values = INPUT_VARIABLE(2);
NDArray *def = nullptr;
auto output = OUTPUT_VARIABLE(0);
if (block.width() > 3)
def = INPUT_VARIABLE(3);
nd4j::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output);
return Status::OK();
};
DECLARE_SHAPE_FN(compat_sparse_to_dense) {
auto indices = INPUT_VARIABLE(0);
auto shape = INPUT_VARIABLE(1);
auto values = INPUT_VARIABLE(2);
if (block.width() > 3) {
auto def = INPUT_VARIABLE(3);
REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values")
};
auto dtype = values->dataType();
// basically output shape is defined by the type of input, and desired shape input
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector<Nd4jLong>()));
}
DECLARE_TYPES(compat_sparse_to_dense) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS}) // indices
->setAllowedInputTypes(1, {ALL_INTS}) // shape
->setAllowedInputTypes(2,nd4j::DataType::ANY) // sparse values
->setAllowedInputTypes(3,nd4j::DataType::ANY) // default value
->setAllowedOutputTypes(nd4j::DataType::ANY);
}
}
}
#endif

View File

@ -0,0 +1,140 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
#include <helpers/StringUtils.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
auto indices = OUTPUT_VARIABLE(0);
auto values = OUTPUT_VARIABLE(1);
auto d = delim->e<std::string>(0);
input->syncToHost();
delim->syncToHost();
// output rank N+1 wrt input rank
std::vector<Nd4jLong> ocoords(input->rankOf() + 1);
std::vector<Nd4jLong> icoords(input->rankOf());
// getting buffer lengths
// FIXME: it'll be bigger, since it'll include delimiters,
auto outputLength = StringUtils::byteLength(*input);
uint64_t ss = 0L;
Nd4jLong ic = 0L;
// loop through each string within tensor
for (auto e = 0L; e < input->lengthOf(); e++) {
// now we should map substring to indices
auto s = input->e<std::string>(e);
// getting base index
shape::index2coords(e, input->shapeInfo(), icoords.data());
// getting number of substrings
auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
// filling output indices
for (uint64_t f = 0; f < cnt; f++) {
for (auto v: icoords)
indices->p(ic++, v);
// last index
indices->p(ic++, f);
}
ss += cnt;
}
// process strings now
std::vector<std::string> strings;
for (auto e = 0L; e < input->lengthOf(); e++) {
auto split = StringUtils::split(input->e<std::string>(e), d);
for (const auto &s:split)
strings.emplace_back(s);
}
// now once we have all strings in single vector time to fill
auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly
values->dataBuffer()->allocatePrimary();
values->dataBuffer()->expand(blen);
memcpy(values->buffer(), tmp.buffer(), blen);
values->tickWriteHost();
// special case, for future use
indices->syncToDevice();
values->syncToDevice();
// we have to tick buffers
values->dataBuffer()->writePrimary();
values->dataBuffer()->readSpecial();
return Status::OK();
};
DECLARE_SHAPE_FN(compat_string_split) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
auto d = delim->e<std::string>(0);
// count number of delimiter substrings in all strings within input tensor
uint64_t cnt = 0;
for (auto e = 0L; e < input->lengthOf(); e++) {
// FIXME: bad, not UTF-compatible
auto s = input->e<std::string>(e);
// each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays
cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
}
// shape calculations
// virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension
// values tensor is going to be vector always
// indices tensor is going to be vector with length equal to values.length * output rank
auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8);
auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64);
return SHAPELIST(indicesShape, valuesShape);
}
DECLARE_TYPES(compat_string_split) {
getOpDescriptor()
->setAllowedInputTypes({ALL_STRINGS})
->setAllowedOutputTypes(0, {ALL_INDICES})
->setAllowedOutputTypes(1, {ALL_STRINGS});
}
}
}
#endif

View File

@ -47,8 +47,7 @@ namespace nd4j {
} }
// just memcpy data // just memcpy data
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer());
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
return Status::OK(); return Status::OK();
} }

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
return Status::OK();
};
DECLARE_SHAPE_FN(split_string) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
return SHAPELIST();
}
DECLARE_TYPES(split_string) {
getOpDescriptor()
->setAllowedInputTypes({ALL_STRINGS})
->setAllowedOutputTypes({ALL_STRINGS});
}
}
}
#endif

View File

@ -0,0 +1,52 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_print_affinity)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) {
// TODO: make this op compatible with ArrayList etc
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
nd4j_printf("<Node %i>: Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->getBuffer(), input->getSpecialBuffer(), input->dataBuffer()->getLenInBytes());
return Status::OK();
}
DECLARE_TYPES(print_affinity) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_STRINGS})
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(print_affinity) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
}
}
}
#endif

View File

@ -0,0 +1,77 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_print_variable)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) {
// TODO: make this op compatible with ArrayList etc
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
std::string str;
if (block.width() == 2) {
auto message = INPUT_VARIABLE(1);
REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String");
str = message->e<std::string>(0);
}
bool printSpecial = false;
if (block.numB() > 0)
printSpecial = B_ARG(0);
if (printSpecial && !nd4j::Environment::getInstance()->isCPU()) {
// only specific backends support special printout. for cpu-based backends it's the same as regular print
if (block.width() == 2)
helpers::print_special(*block.launchContext(), *input, str);
else
helpers::print_special(*block.launchContext(), *input);
} else {
// optionally add message to the print out
if (block.width() == 2) {
input->printIndexedBuffer(str.c_str());
} else {
input->printIndexedBuffer();
}
}
return Status::OK();
}
DECLARE_TYPES(print_variable) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_STRINGS})
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(print_variable) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
}
}
}
#endif

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_COMPAT_H
#define SAMEDIFF_COMPAT_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation splits input string into pieces separated by delimiter
* PLEASE NOTE: This implementation is compatible with TF 1.x
*
* Input[0] - string to split
* Input[1] - delimiter
*
* Returns:
* Output[0] - indices tensor
* Output[1] - values tensor
*/
#if NOT_EXCLUDED(OP_compat_string_split)
DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0);
#endif
/**
* This operation converts TF sparse array representation to dense NDArray
*/
#if NOT_EXCLUDED(OP_compat_sparse_to_dense)
DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0);
#endif
}
}
#endif //SAMEDIFF_COMPAT_H

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_STRINGS_H
#define SAMEDIFF_STRINGS_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation splits input string into pieces separated by delimiter
*
* Input[0] - string to split
* Input[1] - delimiter
*/
#if NOT_EXCLUDED(OP_split_string)
DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0);
#endif
}
}
#endif //SAMEDIFF_STRINGS_H

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_UTILS_H
#define LIBND4J_UTILS_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation prints out NDArray content, either on host or device.
*/
#if NOT_EXCLUDED(OP_print_variable)
DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0);
#endif
/**
* This operation prints out affinity & locality status of given NDArray
*/
#if NOT_EXCLUDED(OP_print_affinity)
DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0);
#endif
}
}
#endif //LIBND4J_UTILS_H

View File

@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
namespace helpers {
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
array.printIndexedBuffer(message.c_str());
}
}
}
}

View File

@ -40,15 +40,11 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
const auto y = reinterpret_cast<const Y*>(vy); const auto y = reinterpret_cast<const Y*>(vy);
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
__shared__ Nd4jLong xzLen, totalThreads, *sharedMem; __shared__ Nd4jLong xzLen;
__shared__ int xzRank, yRank; __shared__ int xzRank, yRank;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xzLen = shape::length(xShapeInfo); xzLen = shape::length(xShapeInfo);
totalThreads = gridDim.x * blockDim.x;
xzRank = shape::rank(xShapeInfo); xzRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo); yRank = shape::rank(yShapeInfo);
@ -56,18 +52,15 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
__syncthreads(); __syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong* coords = sharedMem + threadIdx.x * xzRank; Nd4jLong coords[MAX_RANK];
for (int i = tid; i < xzLen; i += totalThreads) {
for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) {
shape::index2coords(i, xShapeInfo, coords); shape::index2coords(i, xShapeInfo, coords);
const auto xzOffset = shape::getOffset(xShapeInfo, coords); const auto xzOffset = shape::getOffset(xShapeInfo, coords);
const auto xVal = x[xzOffset]; const auto xVal = x[xzOffset];
if(xVal < 0) { if(xVal < 0) {
for (uint j = 0; j < yRank; ++j) for (uint j = 0; j < yRank; ++j)
if(yShapeInfo[j + 1] == 1) if(yShapeInfo[j + 1] == 1)
coords[j + 1] = 0; coords[j + 1] = 0;
@ -82,7 +75,6 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz); preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
} }
@ -91,9 +83,9 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a
PointersManager manager(context, "prelu"); PointersManager manager(context, "prelu");
const int threadsPerBlock = MAX_NUM_THREADS / 2; const int threadsPerBlock = 256;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int blocksPerGrid = 512;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; const int sharedMem = 512;
const auto xType = input.dataType(); const auto xType = input.dataType();
const auto yType = alpha.dataType(); const auto yType = alpha.dataType();
@ -119,13 +111,10 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
auto dLdI = reinterpret_cast<Y*>(vdLdI); auto dLdI = reinterpret_cast<Y*>(vdLdI);
auto dLdA = reinterpret_cast<Y*>(vdLdA); auto dLdA = reinterpret_cast<Y*>(vdLdA);
__shared__ Nd4jLong inLen, totalThreads, *sharedMem; __shared__ Nd4jLong inLen, totalThreads;
__shared__ int inRank, alphaRank; __shared__ int inRank, alphaRank;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
inLen = shape::length(inShapeInfo); inLen = shape::length(inShapeInfo);
totalThreads = gridDim.x * blockDim.x; totalThreads = gridDim.x * blockDim.x;
@ -135,10 +124,9 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
__syncthreads(); __syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong* coords = sharedMem + threadIdx.x * inRank; Nd4jLong coords[MAX_RANK];
for (int i = tid; i < inLen; i += totalThreads) { for (int i = tid; i < inLen; i += totalThreads) {
shape::index2coords(i, inShapeInfo, coords); shape::index2coords(i, inShapeInfo, coords);
const auto inOffset = shape::getOffset(inShapeInfo, coords); const auto inOffset = shape::getOffset(inShapeInfo, coords);
@ -175,14 +163,13 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
dLdA.nullify(); dLdA.nullify();
PointersManager manager(context, "preluBP"); PointersManager manager(context, "preluBP");
const int threadsPerBlock = MAX_NUM_THREADS / 2; const int threadsPerBlock = 256;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int blocksPerGrid = 512;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; const int sharedMem = 512;
const auto xType = input.dataType(); const auto xType = input.dataType();
const auto zType = alpha.dataType(); const auto zType = alpha.dataType();

View File

@ -0,0 +1,61 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/print_variable.h>
#include <helpers/PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) {
auto length = shape::length(shapeInfo);
auto x = reinterpret_cast<const T*>(special);
// TODO: add formatting here
printf("[");
for (uint64_t e = 0; e < length; e++) {
printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]);
if (e < length - 1)
printf(", ");
}
printf("]\n");
}
template <typename T>
static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) {
print_device<T><<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo);
}
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
NDArray::prepareSpecialUse({}, {&array});
PointersManager pm(&ctx, "print_device");
BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.getSpecialBuffer(), array.getSpecialShapeInfo()), LIBND4J_TYPES)
pm.synchronize();
NDArray::registerSpecialUse({}, {&array});
}
}
}
}

View File

@ -41,6 +41,9 @@
#include <helpers/DebugHelper.h> #include <helpers/DebugHelper.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <DebugHelper.h>
#endif // CUDACC #endif // CUDACC
#endif // LIBND4J_HELPERS_H #endif // LIBND4J_HELPERS_H

View File

@ -0,0 +1,123 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/sparse_to_dense.h>
#include <helpers/StringUtils.h>
#include <helpers/ShapeUtils.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename I>
static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) {
auto values = reinterpret_cast<const X*>(vvalues);
auto indices = reinterpret_cast<const I*>(vindices);
auto output = reinterpret_cast<X*>(voutput);
Nd4jLong coords[MAX_RANK];
uint64_t pos = 0;
for (uint64_t e = 0L; e < length; e++) {
// indices come in blocks
for (uint8_t p = 0; p < rank; p++) {
coords[p] = indices[pos++];
}
// fill output at given coords with sparse value
output[shape::getOffset(zShapeInfo, coords)] = values[e];
}
}
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) {
// make sure host buffer is updated
values.syncToHost();
indices.syncToHost();
auto rank = output.rankOf();
if (output.isS()) {
// string case is not so trivial, since elements might, and probably will, have different sizes
auto numValues = values.lengthOf();
auto numElements = output.lengthOf();
// first of all we calculate final buffer sizes and offsets
auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def);
auto valuesLength = StringUtils::byteLength(values);
auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength;
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements);
// now we make sure our output buffer can hold results
output.dataBuffer()->expand( bufferLength + headerLength);
std::vector<Nd4jLong> outputCoords(rank);
std::vector<Nd4jLong> valueCoords(rank);
auto offsetsBuffer = output.bufferAsT<Nd4jLong>();
auto dataBuffer = reinterpret_cast<uint8_t*>(offsetsBuffer + output.lengthOf());
offsetsBuffer[0] = 0;
// getting initial value coords
for (int e = 0; e < rank; e++)
valueCoords[e] = indices.e<Nd4jLong>(e);
// write results individually
for (uint64_t e = 0; e < numElements; e++) {
auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data());
auto cLength = 0L;
std::string str;
if (vIndex == e) {
// we're writing down sparse value here
str = values.e<std::string>(e);
} else {
// we're writing down default value if it exists
if (def != nullptr)
str = def->e<std::string>(0);
else
str = "";
}
// TODO: make it unicode compliant
memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length());
// writing down offset
offsetsBuffer[e+1] = cLength;
}
} else {
// numeric case is trivial, since all elements have equal sizes
// write out default values, if they are present
if (def != nullptr) {
output.assign(def);
// make sure output is synced back
output.syncToHost();
}
// write out values
BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.getBuffer(), indices.getBuffer(), output.buffer(), output.getShapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES);
}
// copy back to device, if there's any
output.syncToDevice();
}
}
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_PRINT_VARIABLE_H
#define LIBND4J_PRINT_VARIABLE_H
#include <ops/declarable/helpers/helpers.h>
namespace nd4j {
namespace ops {
namespace helpers {
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {});
}
}
}
#endif //LIBND4J_PRINT_VARIABLE_H

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_SPARSE_TO_DENSE_H
#define SAMEDIFF_SPARSE_TO_DENSE_H
#include <ops/declarable/helpers/helpers.h>
namespace nd4j {
namespace ops {
namespace helpers {
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output);
}
}
}
#endif //SAMEDIFF_SPARSE_TO_DENSE_H

View File

@ -634,7 +634,7 @@
#define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) #define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
#define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
#define ALL_STRINGS nd4j::DataType::UTF8, nd4j::DataType::UTF16, nd4j::DataType::UTF32
#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64 #define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64
#define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64 #define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64
#define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16 #define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16

View File

@ -810,9 +810,10 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream()); nativeStart[1] = (x.getContext()->getCudaStream());
#endif #endif
OpaqueDataBuffer xBuf(x.dataBuffer());
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), OpaqueDataBuffer zBuf(z.dataBuffer());
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
@ -844,8 +845,10 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream()); nativeStart[1] = (x.getContext()->getCudaStream());
#endif #endif
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), OpaqueDataBuffer xBuf(x.dataBuffer());
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), OpaqueDataBuffer zBuf(z.dataBuffer());
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.specialShapeInfo(),
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); zTadPack.platformShapeInfo(), zTadPack.platformOffsets());

View File

@ -0,0 +1,94 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests17 : public testing::Test {
public:
DeclarableOpsTests17() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
auto values = NDArrayFactory::create<float>({1.f, 2.f, 3.f});
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
auto def = NDArrayFactory::create<float>(0.f);
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f});
nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
auto def = NDArrayFactory::string("d");
auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
auto delimiter = NDArrayFactory::string(" ");
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
nd4j::ops::compat_string_split op;
auto result = op.execute({&x, &delimiter}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(2, result->size());
auto z0 = result->at(0);
auto z1 = result->at(1);
ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp1.isSameShape(z1));
ASSERT_EQ(exp0, *z0);
ASSERT_EQ(exp1, *z1);
delete result;
}

View File

@ -0,0 +1,52 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests18 : public testing::Test {
public:
DeclarableOpsTests18() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests18, test_bitcast_1) {
auto x = NDArrayFactory::create<double>(0.23028551377579154);
auto z = NDArrayFactory::create<Nd4jLong>(0);
auto e = NDArrayFactory::create<Nd4jLong>(4597464930322771456L);
nd4j::ops::bitcast op;
auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests19 : public testing::Test {
public:
DeclarableOpsTests19() {
printf("\n");
fflush(stdout);
}
};

View File

@ -834,12 +834,17 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1}); auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dims.dataBuffer());
execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(),
packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
@ -981,10 +986,14 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
OpaqueDataBuffer xBuf(arrayX.dataBuffer());
OpaqueDataBuffer yBuf(arrayY.dataBuffer());
OpaqueDataBuffer zBuf(arrayZ.dataBuffer());
execPairwiseTransform(nullptr, pairwise::Add, execPairwiseTransform(nullptr, pairwise::Add,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), &xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(),
arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(), &yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), &zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(),
nullptr); nullptr);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
@ -1220,10 +1229,10 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) {
auto z = NDArrayFactory::create<bfloat16>('c', {10}); auto z = NDArrayFactory::create<bfloat16>('c', {10});
RandomGenerator rng(119, 323841120L); RandomGenerator rng(119, 323841120L);
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args); OpaqueDataBuffer zBuf(z.dataBuffer());
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args);
//z.printIndexedBuffer("z"); //z.printIndexedBuffer("z");
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0); ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
} }
@ -1267,6 +1276,64 @@ TEST_F(JavaInteropTests, test_size_dtype_1) {
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
auto d = NDArrayFactory::string(" ");
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""});
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
InteropDataBuffer iz0(z0.dataBuffer());
InteropDataBuffer iz1(z1.dataBuffer());
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo());
ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo());
nd4j::ops::compat_string_split op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp0, z0);
ASSERT_EQ(exp1, z1);
}
TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3, 3, 3});
auto z = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
double buffer[2048];
InteropDataBuffer ix(0, DataType::DOUBLE, false);
InteropDataBuffer iy(0, DataType::DOUBLE, false);
InteropDataBuffer iz(0, DataType::DOUBLE, false);
// we're imitating workspace-managed array here
ix.setPrimary(buffer + 64, x.lengthOf());
iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf());
iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf());
Context ctx(1);
ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo());
ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo());
ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo());
ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});
nd4j::ops::maxpool2d_bp op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}
/* /*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) { TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -470,12 +470,16 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineSimilarity, execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -506,14 +510,17 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance, execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -543,14 +550,17 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance, execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
// z.printIndexedBuffer("z"); // z.printIndexedBuffer("z");
@ -583,13 +593,16 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance, execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -615,10 +628,15 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
NDArray::prepareSpecialUse({&z}, {&x, &y}); NDArray::prepareSpecialUse({&z}, {&x, &y});
execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), OpaqueDataBuffer xBuf(x.dataBuffer());
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), OpaqueDataBuffer yBuf(y.dataBuffer());
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), OpaqueDataBuffer zBuf(z.dataBuffer());
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
@ -730,13 +748,16 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) {
auto z = NDArrayFactory::create<float>('c', {0, 2}); auto z = NDArrayFactory::create<float>('c', {0, 2});
auto e = NDArrayFactory::create<float>('c', {0, 2}); auto e = NDArrayFactory::create<float>('c', {0, 2});
InteropDataBuffer xdb(x.dataBuffer());
InteropDataBuffer ddb(d.dataBuffer());
InteropDataBuffer zdb(z.dataBuffer());
::execReduceSame2(nullptr, reduce::SameOps::Sum, ::execReduceSame2(nullptr, reduce::SameOps::Sum,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &xdb, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zdb, z.shapeInfo(), z.specialShapeInfo(),
d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); &ddb, d.shapeInfo(), d.specialShapeInfo());
} }

View File

@ -119,13 +119,15 @@ TEST_F(NativeOpsTests, ExecIndexReduce_1) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execIndexReduceScalar(nullptr, ::execIndexReduceScalar(nullptr,
indexreduce::IndexMax, indexreduce::IndexMax,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(),
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), nullptr,
nullptr, nullptr); &expBuf, exp.shapeInfo(),
nullptr);
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 4LL); ASSERT_TRUE(exp.e<Nd4jLong>(0) == 4LL);
#endif #endif
@ -140,15 +142,18 @@ TEST_F(NativeOpsTests, ExecIndexReduce_2) {
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
NDArray dimension = NDArrayFactory::create<int>({}); NDArray dimension = NDArrayFactory::create<int>({});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimensionBuf(dimension.dataBuffer());
::execIndexReduce(nullptr, ::execIndexReduce(nullptr,
indexreduce::IndexMax, indexreduce::IndexMax,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(),
nullptr, nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(), &dimensionBuf, dimension.shapeInfo(),
nullptr, nullptr); nullptr);
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 24LL); ASSERT_TRUE(exp.e<Nd4jLong>(0) == 24LL);
#endif #endif
@ -166,16 +171,21 @@ TEST_F(NativeOpsTests, ExecBroadcast_1) {
#else #else
auto dimension = NDArrayFactory::create<int>('c', {1}, {1}); auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execBroadcast(nullptr, ::execBroadcast(nullptr,
broadcast::Add, broadcast::Add,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(),
nullptr, nullptr, nullptr,
y.buffer(), y.shapeInfo(), &yBuf, y.shapeInfo(),
nullptr, nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(),
nullptr, nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(), &dimBuf, dimension.shapeInfo(),
nullptr, nullptr); nullptr);
ASSERT_TRUE(exp.e<float>(0) == 3.); ASSERT_TRUE(exp.e<float>(0) == 3.);
#endif #endif
@ -194,17 +204,18 @@ printf("Unsupported for cuda now.\n");
int dimd = 0; int dimd = 0;
auto dimension = NDArrayFactory::create<int>('c', {1}, {dimd}); auto dimension = NDArrayFactory::create<int>('c', {1}, {dimd});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execBroadcastBool(nullptr, ::execBroadcastBool(nullptr,
broadcast::EqualTo, broadcast::EqualTo,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr, &yBuf, y.shapeInfo(), nullptr,
y.buffer(), y.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr, nullptr,
nullptr, nullptr, &dimBuf, dimension.shapeInfo(),
exp.buffer(), exp.shapeInfo(), nullptr);
nullptr, nullptr,
nullptr,
dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr);
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0)); ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
#endif #endif
@ -219,14 +230,15 @@ TEST_F(NativeOpsTests, ExecPairwise_1) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execPairwiseTransform(nullptr, ::execPairwiseTransform(nullptr,
pairwise::Add, pairwise::Add,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr, &yBuf, y.shapeInfo(), nullptr,
y.buffer(), y.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
nullptr); nullptr);
ASSERT_TRUE(exp.e<float>(5) == 8.); ASSERT_TRUE(exp.e<float>(5) == 8.);
#endif #endif
@ -243,14 +255,15 @@ TEST_F(NativeOpsTests, ExecPairwise_2) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execPairwiseTransformBool(nullptr, ::execPairwiseTransformBool(nullptr,
pairwise::And, pairwise::And,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr, &yBuf, y.shapeInfo(), nullptr,
y.buffer(), y.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
nullptr); nullptr);
ASSERT_TRUE(exp.e<bool>(5) && !exp.e<bool>(4)); ASSERT_TRUE(exp.e<bool>(5) && !exp.e<bool>(4));
#endif #endif
@ -266,14 +279,14 @@ TEST_F(NativeOpsTests, ReduceTest_1) {
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
auto dimension = NDArrayFactory::create<int>('c', {1}, {1}); auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceFloat(nullptr, ::execReduceFloat(nullptr,
reduce::Mean, reduce::Mean,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr);
nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Mean"); // exp.printIndexedBuffer("Reduce Mean");
ASSERT_TRUE(exp.e<float>(0) == 13.); ASSERT_TRUE(exp.e<float>(0) == 13.);
@ -289,14 +302,14 @@ TEST_F(NativeOpsTests, ReduceTest_2) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceSame(nullptr, ::execReduceSame(nullptr,
reduce::Sum, reduce::Sum,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr);
nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Sum"); // exp.printIndexedBuffer("Reduce Sum");
ASSERT_TRUE(exp.e<float>(0) == 325.); ASSERT_TRUE(exp.e<float>(0) == 325.);
@ -312,14 +325,14 @@ TEST_F(NativeOpsTests, ReduceTest_3) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceBool(nullptr, ::execReduceBool(nullptr,
reduce::All, reduce::All,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr);
nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.e<bool>(0) == true); ASSERT_TRUE(exp.e<bool>(0) == true);
@ -335,14 +348,14 @@ TEST_F(NativeOpsTests, ReduceTest_4) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceLong(nullptr, ::execReduceLong(nullptr,
reduce::CountNonZero, reduce::CountNonZero,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr);
nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero"); // exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL); ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
@ -359,15 +372,16 @@ TEST_F(NativeOpsTests, ReduceTest_5) {
printf("Unsupported for cuda now.\n"); printf("Unsupported for cuda now.\n");
#else #else
auto dimension = NDArrayFactory::create<int>({0, 1}); auto dimension = NDArrayFactory::create<int>({0, 1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduceLong2(nullptr, ::execReduceLong2(nullptr,
reduce::CountNonZero, reduce::CountNonZero,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
nullptr, nullptr, &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero"); // exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL); ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
@ -389,15 +403,17 @@ TEST_F(NativeOpsTests, ReduceTest_6) {
x.p(10, 0); x.p(11, 0); x.p(10, 0); x.p(11, 0);
x.p(15, 0); x.p(16, 0); x.p(17, 0); x.p(15, 0); x.p(16, 0); x.p(17, 0);
x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceLong2(nullptr, ::execReduceLong2(nullptr,
reduce::CountNonZero, reduce::CountNonZero,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), nullptr,
nullptr, nullptr,
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), nullptr,
nullptr, nullptr, &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero"); // exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -421,15 +437,16 @@ TEST_F(NativeOpsTests, ReduceTest_7) {
x.linspace(1.0); x.linspace(1.0);
x.syncToDevice(); x.syncToDevice();
dimension.syncToHost(); dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceFloat2(extra, ::execReduceFloat2(extra,
reduce::Mean, reduce::Mean,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Mean"); // exp.printIndexedBuffer("Reduce Mean");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -453,16 +470,16 @@ TEST_F(NativeOpsTests, ReduceTest_8) {
x.syncToDevice(); x.syncToDevice();
dimension.syncToHost(); dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execReduceSame2(extra, ::execReduceSame2(extra,
reduce::Sum, reduce::Sum,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
z.buffer(), z.shapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
z.specialBuffer(), z.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Sum"); // exp.printIndexedBuffer("Reduce Sum");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -485,15 +502,17 @@ TEST_F(NativeOpsTests, ReduceTest_9) {
x.syncToDevice(); x.syncToDevice();
dimension.syncToHost(); dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceBool2(extra, ::execReduceBool2(extra,
reduce::All, reduce::All,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -518,15 +537,16 @@ TEST_F(NativeOpsTests, Reduce3Test_1) {
y.assign(2.); y.assign(2.);
x.syncToDevice(); x.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduce3(extra, ::execReduce3(extra,
reduce3::Dot, reduce3::Dot,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo());
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo());
//z.printIndexedBuffer("Z"); //z.printIndexedBuffer("Z");
//exp.printIndexedBuffer("Reduce3 Dot"); //exp.printIndexedBuffer("Reduce3 Dot");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -551,15 +571,16 @@ TEST_F(NativeOpsTests, Reduce3Test_2) {
y.assign(2.); y.assign(2.);
x.syncToDevice(); x.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduce3Scalar(extra, ::execReduce3Scalar(extra,
reduce3::Dot, reduce3::Dot,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo());
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce3 Dot"); // exp.printIndexedBuffer("Reduce3 Dot");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -585,17 +606,18 @@ TEST_F(NativeOpsTests, Reduce3Test_3) {
x.syncToDevice(); x.syncToDevice();
dimension.syncToHost(); dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduce3Tad(extra, ::execReduce3Tad(extra,
reduce3::Dot, reduce3::Dot,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
nullptr, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
@ -630,17 +652,18 @@ TEST_F(NativeOpsTests, Reduce3Test_4) {
auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); auto hTADShapeInfoY = tadPackY.primaryShapeInfo();
auto hTADOffsetsY = tadPackY.primaryOffsets(); auto hTADOffsetsY = tadPackY.primaryOffsets();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduce3All(extra, ::execReduce3All(extra,
reduce3::Dot, reduce3::Dot,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
@ -667,14 +690,16 @@ TEST_F(NativeOpsTests, ScalarTest_1) {
//y.assign(2.); //y.assign(2.);
x.syncToDevice(); x.syncToDevice();
z.syncToDevice(); z.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execScalar(extra, ::execScalar(extra,
scalar::Multiply, scalar::Multiply,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -700,14 +725,16 @@ TEST_F(NativeOpsTests, ScalarTest_2) {
//y.assign(2.); //y.assign(2.);
x.syncToDevice(); x.syncToDevice();
z.syncToDevice(); z.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execScalarBool(extra, ::execScalarBool(extra,
scalar::GreaterThan, scalar::GreaterThan,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15) != z.e<bool>(15)); ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15) != z.e<bool>(15));
@ -726,13 +753,14 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
printf("Unsupported for CUDA platform yet.\n"); printf("Unsupported for CUDA platform yet.\n");
return; return;
#endif #endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execSummaryStatsScalar(extra, ::execSummaryStatsScalar(extra,
variance::SummaryStatsVariance, variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
exp.specialBuffer(), exp.specialShapeInfo(), false);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Standard Variance"); // exp.printIndexedBuffer("Standard Variance");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -751,13 +779,13 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
printf("Unsupported for CUDA platform yet.\n"); printf("Unsupported for CUDA platform yet.\n");
return; return;
#endif #endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execSummaryStats(extra, ::execSummaryStats(extra,
variance::SummaryStatsVariance, variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
exp.specialBuffer(), exp.specialShapeInfo(), false);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Standard Variance"); // exp.printIndexedBuffer("Standard Variance");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -777,15 +805,16 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
return; return;
#endif #endif
auto dimensions = NDArrayFactory::create<int>({0, 1}); auto dimensions = NDArrayFactory::create<int>({0, 1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimensions.dataBuffer());
::execSummaryStatsTad(extra, ::execSummaryStatsTad(extra,
variance::SummaryStatsVariance, variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(),
dimensions.buffer(), dimensions.shapeInfo(),
dimensions.specialBuffer(), dimensions.specialShapeInfo(),
false, false,
nullptr, nullptr); nullptr, nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
@ -807,13 +836,15 @@ TEST_F(NativeOpsTests, TransformTest_1) {
return; return;
#endif #endif
z.linspace(1.); z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformFloat(extra, ::execTransformFloat(extra,
transform::Sqrt, transform::Sqrt,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
nullptr); nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Sqrt is"); // exp.printIndexedBuffer("Sqrt is");
@ -834,13 +865,15 @@ TEST_F(NativeOpsTests, TransformTest_2) {
return; return;
#endif #endif
z.linspace(1.); z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformSame(extra, ::execTransformSame(extra,
transform::Square, transform::Square,
z.buffer(), z.shapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
z.specialBuffer(), z.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
nullptr); nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Square is"); // exp.printIndexedBuffer("Square is");
@ -864,13 +897,14 @@ TEST_F(NativeOpsTests, TransformTest_3) {
z.assign(true); z.assign(true);
x.p(24, -25); x.p(24, -25);
z.p(24, false); z.p(24, false);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformBool(extra, ::execTransformBool(extra,
transform::IsPositive, transform::IsPositive,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
nullptr); nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("IsPositive"); // exp.printIndexedBuffer("IsPositive");
@ -894,13 +928,13 @@ TEST_F(NativeOpsTests, TransformTest_4) {
return; return;
#endif #endif
//z.linspace(1.); //z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformStrict(extra, ::execTransformStrict(extra,
transform::Cosine, transform::Cosine,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
nullptr); nullptr);
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Cosine"); // exp.printIndexedBuffer("Cosine");
@ -932,17 +966,18 @@ TEST_F(NativeOpsTests, ScalarTadTest_1) {
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execScalarTad(extra, ::execScalarTad(extra,
scalar::Multiply, scalar::Multiply,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(), &yBuf, y.shapeInfo(), y.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(), &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All"); // exp.printIndexedBuffer("Reduce All");
@ -977,17 +1012,21 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) {
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
z.assign(true); z.assign(true);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execScalarBoolTad(extra, ::execScalarBoolTad(extra,
scalar::And, scalar::And,
x.buffer(), x.shapeInfo(), &xBuf, x.shapeInfo(), x.specialShapeInfo(),
x.specialBuffer(), x.specialShapeInfo(), &expBuf, exp.shapeInfo(),
exp.buffer(), exp.shapeInfo(), exp.specialShapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), &yBuf, y.shapeInfo(),
y.buffer(), y.shapeInfo(), y.specialShapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(), &dimBuf, dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(), dimension.specialShapeInfo(),
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
// x.printIndexedBuffer("Input"); // x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("And"); // exp.printIndexedBuffer("And");
@ -1095,9 +1134,11 @@ TEST_F(NativeOpsTests, PullRowsTest_1) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream()); nativeStart[1] = (x.getContext()->getCudaStream());
#endif #endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
@ -1250,7 +1291,9 @@ TEST_F(NativeOpsTests, RandomTest_1) {
#endif #endif
graph::RandomGenerator rng(1023, 119); graph::RandomGenerator rng(1023, 119);
double p = 0.5; double p = 0.5;
::execRandom(extra, random::BernoulliDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
} }
TEST_F(NativeOpsTests, RandomTest_2) { TEST_F(NativeOpsTests, RandomTest_2) {
@ -1264,7 +1307,10 @@ TEST_F(NativeOpsTests, RandomTest_2) {
x.linspace(0, 0.01); x.linspace(0, 0.01);
graph::RandomGenerator rng(1023, 119); graph::RandomGenerator rng(1023, 119);
double p = 0.5; double p = 0.5;
::execRandom2(extra, random::DropOut, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
} }
TEST_F(NativeOpsTests, RandomTest_3) { TEST_F(NativeOpsTests, RandomTest_3) {
@ -1280,7 +1326,12 @@ TEST_F(NativeOpsTests, RandomTest_3) {
x.linspace(1, -0.01); x.linspace(1, -0.01);
graph::RandomGenerator rng(1023, 119); graph::RandomGenerator rng(1023, 119);
double p = 0.5; double p = 0.5;
::execRandom3(extra, random::ProbablisticMerge, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf,
y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
} }
TEST_F(NativeOpsTests, RandomTest_4) { TEST_F(NativeOpsTests, RandomTest_4) {
@ -1316,6 +1367,10 @@ TEST_F(NativeOpsTests, SortTests_2) {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
extras[1] = LaunchContext::defaultContext()->getCudaStream(); extras[1] = LaunchContext::defaultContext()->getCudaStream();
#endif #endif
// OpaqueDataBuffer xBuf(x.dataBuffer());
// OpaqueDataBuffer yBuf(y.dataBuffer());
// OpaqueDataBuffer expBuf(exp.dataBuffer());
// OpaqueDataBuffer dimBuf(exp.dataBuffer());
::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice(); k.tickWriteDevice();
@ -1541,6 +1596,13 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
::deleteShapeList((Nd4jPointer) shapeList); ::deleteShapeList((Nd4jPointer) shapeList);
} }
TEST_F(NativeOpsTests, interop_databuffer_tests_1) {
auto idb = ::allocateDataBuffer(100, 10, false);
auto ptr = ::dbPrimaryBuffer(idb);
::deleteDataBuffer(idb);
}
//Uncomment when needed only - massive calculations //Uncomment when needed only - massive calculations
//TEST_F(NativeOpsTests, BenchmarkTests_1) { //TEST_F(NativeOpsTests, BenchmarkTests_1) {
// //

View File

@ -91,3 +91,25 @@ TEST_F(StringTests, Basic_dup_1) {
delete dup; delete dup;
} }
TEST_F(StringTests, byte_length_test_1) {
std::string f("alpha");
auto array = NDArrayFactory::string(f);
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
}
TEST_F(StringTests, byte_length_test_2) {
auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"});
ASSERT_EQ(9, StringUtils::byteLength(array));
}
TEST_F(StringTests, test_split_1) {
auto split = StringUtils::split("alpha beta gamma", " ");
ASSERT_EQ(3, split.size());
ASSERT_EQ(std::string("alpha"), split[0]);
ASSERT_EQ(std::string("beta"), split[1]);
ASSERT_EQ(std::string("gamma"), split[2]);
}

View File

@ -1,5 +1,6 @@
package org.nd4j.autodiff.listeners.debugging; package org.nd4j.autodiff.listeners.debugging;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.BaseListener;
@ -113,16 +114,16 @@ public class ExecDebuggingListener extends BaseListener {
if(co.tArgs() != null && co.tArgs().length > 0) { if(co.tArgs() != null && co.tArgs().length > 0) {
sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs())); sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
} }
INDArray[] inputs = co.inputArguments(); val inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments(); val outputs = co.outputArguments();
if(inputs != null ) { if(inputs != null ) {
for (int i = 0; i < inputs.length; i++) { for (int i = 0; i < inputs.size(); i++) {
sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString()); sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString());
} }
} }
if(outputs != null ) { if(outputs != null ) {
for (int i = 0; i < outputs.length; i++) { for (int i = 0; i < outputs.size(); i++) {
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString()); sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString());
} }
} }
} else { } else {
@ -156,22 +157,22 @@ public class ExecDebuggingListener extends BaseListener {
if(co.tArgs() != null && co.tArgs().length > 0 ){ if(co.tArgs() != null && co.tArgs().length > 0 ){
sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n"); sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
} }
INDArray[] inputs = co.inputArguments(); val inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments(); val outputs = co.outputArguments();
if(inputs != null ) { if(inputs != null ) {
sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n"); sb.append("INDArray[] inputs = new INDArray[").append(inputs.size()).append("];\n");
for (int i = 0; i < inputs.length; i++) { for (int i = 0; i < inputs.size(); i++) {
sb.append("inputs[").append(i).append("] = "); sb.append("inputs[").append(i).append("] = ");
sb.append(createString(inputs[i])) sb.append(createString(inputs.get(i)))
.append(";\n"); .append(";\n");
} }
sb.append("op.addInputArgument(inputs);\n"); sb.append("op.addInputArgument(inputs);\n");
} }
if(outputs != null ) { if(outputs != null ) {
sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n"); sb.append("INDArray[] outputs = new INDArray[").append(outputs.size()).append("];\n");
for (int i = 0; i < outputs.length; i++) { for (int i = 0; i < outputs.size(); i++) {
sb.append("outputs[").append(i).append("] = "); sb.append("outputs[").append(i).append("] = ");
sb.append(createString(outputs[i])) sb.append(createString(outputs.get(i)))
.append(";\n"); .append(";\n");
} }
sb.append("op.addOutputArgument(outputs);\n"); sb.append("op.addOutputArgument(outputs);\n");

View File

@ -478,11 +478,11 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
} }
throw new IllegalStateException(s); throw new IllegalStateException(s);
} }
return ((Assert) op).outputArguments(); return ((Assert) op).outputArguments().toArray(new INDArray[0]);
} else if (op instanceof CustomOp) { } else if (op instanceof CustomOp) {
CustomOp c = (CustomOp) op; CustomOp c = (CustomOp) op;
Nd4j.exec(c); Nd4j.exec(c);
return c.outputArguments(); return c.outputArguments().toArray(new INDArray[0]);
} else if (op instanceof Op) { } else if (op instanceof Op) {
Op o = (Op) op; Op o = (Op) op;
Nd4j.exec(o); Nd4j.exec(o);

View File

@ -457,7 +457,7 @@ public class OpValidation {
for (int i = 0; i < testCase.testFns().size(); i++) { for (int i = 0; i < testCase.testFns().size(); i++) {
String error; String error;
try { try {
error = testCase.testFns().get(i).apply(testCase.op().outputArguments()[i]); error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i));
} catch (Throwable t) { } catch (Throwable t) {
throw new IllegalStateException("Exception thrown during op output validation for output " + i, t); throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
} }

View File

@ -1,6 +1,7 @@
package org.nd4j.autodiff.validation.listeners; package org.nd4j.autodiff.validation.listeners;
import lombok.Getter; import lombok.Getter;
import lombok.val;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.listeners.Operation;
@ -50,12 +51,12 @@ public class NonInplaceValidationListener extends BaseListener {
opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
} }
} else if(op.getOp() instanceof DynamicCustomOp){ } else if(op.getOp() instanceof DynamicCustomOp){
INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments(); val arr = ((DynamicCustomOp) op.getOp()).inputArguments();
opInputs = new INDArray[arr.length]; opInputs = new INDArray[arr.size()];
opInputsOrig = new INDArray[arr.length]; opInputsOrig = new INDArray[arr.size()];
for( int i=0; i<arr.length; i++ ){ for( int i=0; i<arr.size(); i++ ){
opInputsOrig[i] = arr[i]; opInputsOrig[i] = arr.get(i);
opInputs[i] = arr[i].dup(); opInputs[i] = arr.get(i).dup();
} }
} else { } else {
throw new IllegalStateException("Unknown op type: " + op.getOp().getClass()); throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());

View File

@ -589,6 +589,10 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.Range.class, org.nd4j.linalg.api.ops.random.impl.Range.class,
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class, org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class, org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
org.nd4j.linalg.api.ops.util.PrintAffinity.class,
org.nd4j.linalg.api.ops.util.PrintVariable.class,
org.nd4j.linalg.api.ops.compat.CompatSparseToDense.class,
org.nd4j.linalg.api.ops.compat.CompatStringSplit.class,
org.nd4j.linalg.api.ops.custom.AdjustContrast.class, org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.HsvToRgb.class, org.nd4j.linalg.api.ops.custom.HsvToRgb.class,

View File

@ -73,7 +73,7 @@ public class ActivationPReLU extends BaseActivationFunction {
preluBp.addIntegerArguments(axis); preluBp.addIntegerArguments(axis);
} }
} }
Nd4j.getExecutioner().execAndReturn(preluBp.build()); Nd4j.exec(preluBp.build());
in.assign(outTemp); in.assign(outTemp);
return new Pair<>(in, dLdalpha); return new Pair<>(in, dLdalpha);
} }

View File

@ -23,7 +23,6 @@ import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import net.ericaro.neoitertools.Generator;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -998,14 +997,14 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
} }
Pair<DataBuffer, DataBuffer> tadInfo = Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
DataBuffer shapeInfo = tadInfo.getFirst(); DataBuffer shapeInfo = tadInfo.getFirst();
val shape = Shape.shape(shapeInfo); val jShapeInfo = shapeInfo.asLong();
val stride = Shape.stride(shapeInfo).asLong(); val shape = Shape.shape(jShapeInfo);
val stride = Shape.stride(jShapeInfo);
long offset = offset() + tadInfo.getSecond().getLong(index); long offset = offset() + tadInfo.getSecond().getLong(index);
val ews = shapeInfo.getLong(shapeInfo.getLong(0) * 2 + 2); val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2);
char tadOrder = (char) shapeInfo.getInt(shapeInfo.getLong(0) * 2 + 3); char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3);
val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder);
return toTad; return toTad;
} }
@ -2217,9 +2216,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if(isEmpty() || isS()) if(isEmpty() || isS())
return false; return false;
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0 val c2 = (length() < data().length() && data.dataType() != DataType.INT);
|| (length() < data().length() && data.dataType() != DataType.INT) val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer());
|| data().originalDataBuffer() != null;
return c2 || c3;
} }
@Override @Override
@ -3585,6 +3585,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
case DOUBLE: case DOUBLE:
case FLOAT: case FLOAT:
case HALF: case HALF:
case BFLOAT16:
return getDouble(i); return getDouble(i);
case LONG: case LONG:
case INT: case INT:
@ -3592,6 +3593,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
case UBYTE: case UBYTE:
case BYTE: case BYTE:
case BOOL: case BOOL:
case UINT64:
case UINT32:
case UINT16:
return getLong(i); return getLong(i);
case UTF8: case UTF8:
case COMPRESSED: case COMPRESSED:
@ -4350,29 +4354,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
//epsilon equals //epsilon equals
if (isScalar() && n.isScalar()) { if (isScalar() && n.isScalar()) {
if (data.dataType() == DataType.FLOAT) { if (isZ()) {
double val = getDouble(0); val val = getLong(0);
double val2 = n.getDouble(0); val val2 = n.getLong(0);
return val == val2;
} else if (isR()) {
val val = getDouble(0);
val val2 = n.getDouble(0);
if (Double.isNaN(val) != Double.isNaN(val2)) if (Double.isNaN(val) != Double.isNaN(val2))
return false; return false;
return Math.abs(val - val2) < eps; return Math.abs(val - val2) < eps;
} else { } else if (isB()) {
double val = getDouble(0); val val = getInt(0);
double val2 = n.getDouble(0); val val2 = n.getInt(0);
if (Double.isNaN(val) != Double.isNaN(val2)) return val == val2;
return false;
return Math.abs(val - val2) < eps;
} }
} else if (isVector() && n.isVector()) { } else if (isVector() && n.isVector()) {
val op = new EqualsWithEps(this, n, eps);
EqualsWithEps op = new EqualsWithEps(this, n, eps); Nd4j.exec(op);
Nd4j.getExecutioner().exec(op); val diff = op.z().getDouble(0);
double diff = op.z().getDouble(0);
return diff < 0.5; return diff < 0.5;
} }
@ -4750,8 +4755,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this; return this;
checkArrangeArray(rearrange); checkArrangeArray(rearrange);
int[] newShape = doPermuteSwap(shapeOf(), rearrange); val newShape = doPermuteSwap(shape(), rearrange);
int[] newStride = doPermuteSwap(strideOf(), rearrange); val newStride = doPermuteSwap(stride(), rearrange);
char newOrder = Shape.getOrder(newShape, newStride, 1); char newOrder = Shape.getOrder(newShape, newStride, 1);
@ -4777,23 +4782,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this; return this;
checkArrangeArray(rearrange); checkArrangeArray(rearrange);
val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange); val newShape = doPermuteSwap(shape(), rearrange);
val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange); val newStride = doPermuteSwap(stride(), rearrange);
char newOrder = Shape.getOrder(newShape, newStride, 1); char newOrder = Shape.getOrder(newShape, newStride, 1);
//Set the shape information of this array: shape, stride, order.
//Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order]
/*for( int i=0; i<rank; i++ ){
shapeInfo.put(1+i,newShape[i]);
shapeInfo.put(1+i+rank,newStride[i]);
}
shapeInfo.put(3+2*rank,newOrder);
*/
val ews = shapeInfo.get(2 * rank + 2); val ews = shapeInfo.get(2 * rank + 2);
/*
if (ews < 1 && !attemptedToFindElementWiseStride)
throw new RuntimeException("EWS is -1");
*/
val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty());
setShapeInformation(si); setShapeInformation(si);
@ -4813,6 +4806,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
@Deprecated
protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) {
val ret = new long[rearrange.length]; val ret = new long[rearrange.length];
for (int i = 0; i < rearrange.length; i++) { for (int i = 0; i < rearrange.length; i++) {
@ -4821,6 +4815,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret; return ret;
} }
@Deprecated
protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) {
int[] ret = new int[rearrange.length]; int[] ret = new int[rearrange.length];
for (int i = 0; i < rearrange.length; i++) { for (int i = 0; i < rearrange.length; i++) {
@ -4829,11 +4824,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret; return ret;
} }
@Deprecated
protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) {
int[] ret = new int[rearrange.length]; int[] ret = new int[rearrange.length];
for (int i = 0; i < rearrange.length; i++) { for (int i = 0; i < rearrange.length; i++) {
ret[i] = shape.getInt(rearrange[i]); ret[i] = shape.getInt(rearrange[i]);
} }
return ret;
}
protected long[] doPermuteSwap(long[] shape, int[] rearrange) {
val ret = new long[rearrange.length];
for (int i = 0; i < rearrange.length; i++) {
ret[i] = shape[rearrange[i]];
}
return ret; return ret;
} }
@ -5413,29 +5417,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
} }
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) { protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer);
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
val numWords = this.length();
val ub = (Utf8Buffer) buffer;
// writing length first
val t = length();
val ptr = (BytePointer) ub.pointer();
// now write all strings as bytes
for (int i = 0; i < ub.length(); i++) {
dos.writeByte(ptr.get(i));
}
val bytes = bos.toByteArray();
return FlatArray.createBufferVector(builder, bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override @Override
public int toFlatArray(FlatBufferBuilder builder) { public int toFlatArray(FlatBufferBuilder builder) {
@ -5543,13 +5525,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return !any(); return !any();
} }
@Override
public String getString(long index) {
if (!isS())
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
return ((Utf8Buffer) data).getString(index);
}
/** /**
* Validate that the operation is being applied on a numerical array (not boolean or utf8). * Validate that the operation is being applied on a numerical array (not boolean or utf8).

View File

@ -47,12 +47,9 @@ public interface CustomOp {
*/ */
boolean isInplaceCall(); boolean isInplaceCall();
List<INDArray> outputArguments();
List<INDArray> inputArguments();
INDArray[] outputArguments();
INDArray[] inputArguments();
long[] iArgs(); long[] iArgs();

View File

@ -261,19 +261,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
} }
@Override @Override
public INDArray[] outputArguments() { public List<INDArray> outputArguments() {
if (!outputArguments.isEmpty()) { return outputArguments;
return outputArguments.toArray(new INDArray[0]);
}
return new INDArray[0];
} }
@Override @Override
public INDArray[] inputArguments() { public List<INDArray> inputArguments() {
if (!inputArguments.isEmpty()) return inputArguments;
return inputArguments.toArray(new INDArray[0]);
return new INDArray[0];
} }
@Override @Override
@ -367,10 +361,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
// it's possible to get into situation where number of args > number of arrays AT THIS MOMENT // it's possible to get into situation where number of args > number of arrays AT THIS MOMENT
if (i >= arrsSoFar.length) if (i >= arrsSoFar.size())
continue; continue;
if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape())) if (!Arrays.equals(args[i].getShape(), arrsSoFar.get(i).shape()))
throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape())); throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape()));
} }
} }

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.compat;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* This is a wrapper for SparseToDense op that impelements corresponding TF operation
*
* @author raver119@gmail.com
*/
public class CompatSparseToDense extends DynamicCustomOp {
public CompatSparseToDense() {
//
}
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values) {
Preconditions.checkArgument(shape.isZ() && indices.isZ(), "Shape & indices arrays must have one integer data types");
inputArguments.add(indices);
inputArguments.add(shape);
inputArguments.add(values);
}
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values, INDArray defaultVaule) {
this(indices, shape, values);
Preconditions.checkArgument(defaultVaule.dataType() == values.dataType(), "Values array must have the same data type as defaultValue array");
inputArguments.add(defaultVaule);
}
@Override
public String opName() {
return "compat_sparse_to_dense";
}
}

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.compat;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* This is a wrapper for StringSplit op that impelements corresponding TF operation
*
* @author raver119@gmail.com
*/
public class CompatStringSplit extends DynamicCustomOp {
public CompatStringSplit() {
//
}
public CompatStringSplit(INDArray strings, INDArray delimiter) {
Preconditions.checkArgument(strings.isS() && delimiter.isS(), "Input arrays must have one of UTF types");
inputArguments.add(strings);
inputArguments.add(delimiter);
}
public CompatStringSplit(INDArray strings, INDArray delimiter, INDArray indices, INDArray values) {
this(strings, delimiter);
outputArguments.add(indices);
outputArguments.add(values);
}
@Override
public String opName() {
return "compat_string_split";
}
}

View File

@ -107,12 +107,12 @@ public class ScatterUpdate implements CustomOp {
} }
@Override @Override
public INDArray[] outputArguments() { public List<INDArray> outputArguments() {
return op.outputArguments(); return op.outputArguments();
} }
@Override @Override
public INDArray[] inputArguments() { public List<INDArray> inputArguments() {
return op.inputArguments(); return op.inputArguments();
} }

View File

@ -23,7 +23,6 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -172,7 +171,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
@Override @Override
public INDArray[] exec(CustomOp op) { public INDArray[] exec(CustomOp op) {
return execAndReturn(op).outputArguments(); return execAndReturn(op).outputArguments().toArray(new INDArray[0]);
} }
@Override @Override
@ -822,7 +821,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
} }
@Override @Override
public String getString(Utf8Buffer buffer, long index) { public String getString(DataBuffer buffer, long index) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }

View File

@ -20,7 +20,6 @@ import lombok.NonNull;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics; import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.*; import org.nd4j.linalg.api.ops.*;
@ -32,8 +31,6 @@ import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.TadPack; import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.List; import java.util.List;
@ -411,7 +408,7 @@ public interface OpExecutioner {
* @param index * @param index
* @return * @return
*/ */
String getString(Utf8Buffer buffer, long index); String getString(DataBuffer buffer, long index);
/** /**
* Temporary hook * Temporary hook

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.util;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
/**
* This is a wrapper for PrintAffinity op that just prints out affinity & locality status of INDArray
*
* @author raver119@gmail.com
*/
public class PrintAffinity extends DynamicCustomOp {
public PrintAffinity() {
//
}
public PrintAffinity(INDArray array) {
inputArguments.add(array);
}
@Override
public String opName() {
return "print_affinity";
}
}

View File

@ -0,0 +1,66 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.util;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
/**
* This is a wrapper for PrintVariable op that just prints out Variable to the stdout
*
* @author raver119@gmail.com
*/
public class PrintVariable extends DynamicCustomOp {
public PrintVariable() {
//
}
public PrintVariable(INDArray array, boolean printSpecial) {
inputArguments.add(array);
bArguments.add(printSpecial);
}
public PrintVariable(INDArray array) {
this(array, false);
}
public PrintVariable(INDArray array, String message, boolean printSpecial) {
this(array, Nd4j.create(message), printSpecial);
}
public PrintVariable(INDArray array, String message) {
this(array, Nd4j.create(message), false);
}
public PrintVariable(INDArray array, INDArray message, boolean printSpecial) {
this(array, printSpecial);
Preconditions.checkArgument(message.isS(), "Message argument should have String data type, but got [" + message.dataType() +"] instead");
inputArguments.add(message);
}
public PrintVariable(INDArray array, INDArray message) {
this(array, message, false);
}
@Override
public String opName() {
return "print_variable";
}
}

View File

@ -89,6 +89,11 @@ public class CompressedDataBuffer extends BaseDataBuffer {
// no-op // no-op
} }
@Override
public Pointer addressPointer() {
return pointer;
}
/** /**
* Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer * Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer
* @param s * @param s
@ -194,6 +199,15 @@ public class CompressedDataBuffer extends BaseDataBuffer {
*/ */
@Override @Override
public DataBuffer create(int[] data) { public DataBuffer create(int[] data) {
throw new UnsupportedOperationException("This operation isn't supported for CompressedDataBuffer"); throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
public void pointerIndexerByCurrentType(DataType currentType) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
@Override
public DataBuffer reallocate(long length) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
} }
} }

View File

@ -98,7 +98,7 @@ public class Convolution {
.build(); .build();
Nd4j.getExecutioner().execAndReturn(col2Im); Nd4j.getExecutioner().execAndReturn(col2Im);
return col2Im.outputArguments()[0]; return col2Im.outputArguments().get(0);
} }
public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW, public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW,
@ -187,7 +187,7 @@ public class Convolution {
.build()).build(); .build()).build();
Nd4j.getExecutioner().execAndReturn(im2col); Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0]; return im2col.outputArguments().get(0);
} }
public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode, public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode,
@ -208,7 +208,7 @@ public class Convolution {
.build()).build(); .build()).build();
Nd4j.getExecutioner().execAndReturn(im2col); Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0]; return im2col.outputArguments().get(0);
} }
/** /**
@ -298,7 +298,7 @@ public class Convolution {
.build()).build(); .build()).build();
Nd4j.getExecutioner().execAndReturn(im2col); Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0]; return im2col.outputArguments().get(0);
} }
/** /**

View File

@ -40,7 +40,6 @@ import org.nd4j.graph.FlatArray;
import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.buffer.*;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager; import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
@ -1044,16 +1043,7 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) { public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) {
switch (type) { return DATA_BUFFER_FACTORY_INSTANCE.create(buffer, type, length, offset);
case INT:
return DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, buffer, length);
case DOUBLE:
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, buffer, length);
case FLOAT:
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, buffer, length);
default:
throw new IllegalArgumentException("Illegal opType " + type);
}
} }
/** /**
@ -1336,38 +1326,9 @@ public class Nd4j {
* @return the created buffer * @return the created buffer
*/ */
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) { public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
switch (type) { return createBuffer(buffer, type, length, 0);
case INT:
return DATA_BUFFER_FACTORY_INSTANCE.createInt(buffer, length);
case LONG:
return DATA_BUFFER_FACTORY_INSTANCE.createLong(buffer, length);
case DOUBLE:
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(buffer, length);
case FLOAT:
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(buffer, length);
case HALF:
return DATA_BUFFER_FACTORY_INSTANCE.createHalf(buffer, length);
default:
throw new IllegalArgumentException("Illegal opType " + type);
}
} }
/**
* Create a buffer based on the data opType
*
* @param data the data to create the buffer with
* @return the created buffer
*/
public static DataBuffer createBuffer(byte[] data, int length) {
DataBuffer ret;
if (dataType() == DataType.DOUBLE)
ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length);
else if (dataType() == DataType.HALF)
ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length);
else
ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length);
return ret;
}
/** /**
* Create a buffer equal of length prod(shape) * Create a buffer equal of length prod(shape)
@ -2206,6 +2167,7 @@ public class Nd4j {
private static String writeStringForArray(INDArray write) { private static String writeStringForArray(INDArray write) {
if(write.isView() || !Shape.hasDefaultStridesForShape(write)) if(write.isView() || !Shape.hasDefaultStridesForShape(write))
write = write.dup(); write = write.dup();
String format = "0.000000000000000000E0"; String format = "0.000000000000000000E0";
return "{\n" + return "{\n" +
@ -3927,16 +3889,6 @@ public class Nd4j {
return create(shape, stride); return create(shape, stride);
} }
/**
* Creates an ndarray with the specified shape
*
* @param rows the rows of the ndarray
* @param columns the columns of the ndarray
* @return the instance
*/
public static INDArray create(int rows, int columns) {
return create(rows, columns, order());
}
/** /**
* Creates an ndarray with the specified shape * Creates an ndarray with the specified shape
@ -4386,13 +4338,6 @@ public class Nd4j {
return createUninitialized(shape, Nd4j.order()); return createUninitialized(shape, Nd4j.order());
} }
/**
* See {@link #createUninitialized(long)}
*/
public static INDArray createUninitialized(int length) {
return createUninitialized((long)length);
}
/** /**
* This method creates an *uninitialized* ndarray of specified length and default ordering. * This method creates an *uninitialized* ndarray of specified length and default ordering.
* *
@ -4428,37 +4373,6 @@ public class Nd4j {
////////////////////// OTHER /////////////////////////////// ////////////////////// OTHER ///////////////////////////////
/**
* Creates a 2D array with specified number of rows, columns initialized with zero.
*
* @param rows number of rows.
* @param columns number of columns.
* @return the created array.
*/
public static INDArray zeros(long rows, long columns) {
return INSTANCE.zeros(rows, columns);
}
/**
* Creates a 1D array with the specified number of columns initialized with zero.
*
* @param columns number of columns.
* @return the created array
*/
public static INDArray zeros(int columns) {
return INSTANCE.zeros(columns);
}
/**
* Creates a 1D array with the specified data tyoe and number of columns initialized with zero.
*
* @param dataType data type.
* @param columns number of columns.
* @return the created array.
*/
public static INDArray zeros(DataType dataType, int columns) {
return INSTANCE.create(dataType, new long[]{columns}, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
}
/** /**
* Creates an array with the specified data tyoe and shape initialized with zero. * Creates an array with the specified data tyoe and shape initialized with zero.
@ -4468,7 +4382,10 @@ public class Nd4j {
* @return the created array. * @return the created array.
*/ */
public static INDArray zeros(DataType dataType, @NonNull long... shape) { public static INDArray zeros(DataType dataType, @NonNull long... shape) {
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); if(shape.length == 0)
return Nd4j.scalar(dataType, 0);
return INSTANCE.create(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
} }
/** /**
@ -4588,31 +4505,6 @@ public class Nd4j {
return INSTANCE.valueArrayOf(rows, columns, value); return INSTANCE.valueArrayOf(rows, columns, value);
} }
/**
* Creates a row vector with the specified number of columns
*
* @param rows the number of rows in the matrix
* @param columns the columns of the ndarray
* @return the created ndarray
*/
public static INDArray ones(int rows, int columns) {
return INSTANCE.ones(rows, columns);
}
/**
* Create a 2D array with the given rows, columns and data type initialised with ones.
*
* @param dataType data type
* @param rows rows of the new array.
* @param columns columns of the new arrau.
* @return the created array
*/
public static INDArray ones(DataType dataType, int rows, int columns) {
INDArray ret = INSTANCE.createUninitialized(dataType, new long[]{rows, columns}, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
ret.assign(1);
return ret;
}
/** /**
* Empty like * Empty like
* *
@ -4817,8 +4709,7 @@ public class Nd4j {
for (int idx : indexes) { for (int idx : indexes) {
if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) { if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) {
throw new IllegalStateException( throw new IllegalStateException("Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
"Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
} }
} }
@ -5186,7 +5077,7 @@ public class Nd4j {
pp.toString(NDARRAY_FACTORY_CLASS)); pp.toString(NDARRAY_FACTORY_CLASS));
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory");
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
.forName(pp.toString(DATA_BUFFER_OPS, defaultName)); .forName(pp.toString(DATA_BUFFER_OPS, defaultName));
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
@ -5871,7 +5762,7 @@ public class Nd4j {
arr[e] = sb.get(e + pos); arr[e] = sb.get(e + pos);
} }
val buffer = new Utf8Buffer(arr, prod); val buffer = DATA_BUFFER_FACTORY_INSTANCE.createUtf8Buffer(arr, prod);
return Nd4j.create(buffer, shapeOf); return Nd4j.create(buffer, shapeOf);
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@ -30,6 +30,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/** /**
* This class provides unified management for Deallocatable resources * This class provides unified management for Deallocatable resources
@ -43,6 +44,8 @@ public class DeallocatorService {
private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>(); private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>(); private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();
private AtomicLong counter = new AtomicLong(0);
public DeallocatorService() { public DeallocatorService() {
// we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
@ -69,6 +72,10 @@ public class DeallocatorService {
} }
} }
public long nextValue() {
return counter.incrementAndGet();
}
/** /**
* This method adds Deallocatable object instance to tracking system * This method adds Deallocatable object instance to tracking system
* *

View File

@ -17,10 +17,10 @@
package org.nd4j.serde.jackson.shaded; package org.nd4j.serde.jackson.shaded;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonGenerator; import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer; import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider; import org.nd4j.shade.jackson.databind.SerializerProvider;
@ -77,10 +77,9 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
jg.writeNumber(v); jg.writeNumber(v);
break; break;
case UTF8: case UTF8:
Utf8Buffer utf8B = ((Utf8Buffer)arr.data()); val n = arr.length();
long n = utf8B.getNumWords();
for( int j=0; j<n; j++ ) { for( int j=0; j<n; j++ ) {
String s = utf8B.getString(j); String s = arr.getString(j);
jg.writeString(s); jg.writeString(s);
} }
break; break;

View File

@ -16,11 +16,8 @@
package org.nd4j.nativeblas; package org.nd4j.nativeblas;
import lombok.val;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.Cast; import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
/** /**
@ -53,14 +50,12 @@ public interface NativeOps {
*/ */
void execIndexReduceScalar(PointerPointer extraPointers, void execIndexReduceScalar(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dX,
@Cast("Nd4jLong *") LongPointer dXShapeInfo, @Cast("Nd4jLong *") LongPointer dXShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer z, OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dZ,
@Cast("Nd4jLong *") LongPointer dZShapeInfo); @Cast("Nd4jLong *") LongPointer dZShapeInfo);
/** /**
@ -75,17 +70,16 @@ public interface NativeOps {
*/ */
void execIndexReduce(PointerPointer extraPointers, void execIndexReduce(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dX,
@Cast("Nd4jLong *") LongPointer dXShapeInfo, @Cast("Nd4jLong *") LongPointer dXShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dResult,
@Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer, @Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/** /**
* @param opNum * @param opNum
@ -100,38 +94,34 @@ public interface NativeOps {
*/ */
void execBroadcast(PointerPointer extraPointers, void execBroadcast(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y, OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execBroadcastBool(PointerPointer extraPointers, void execBroadcastBool(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y, OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/** /**
@ -146,33 +136,27 @@ public interface NativeOps {
*/ */
void execPairwiseTransform(PointerPointer extraPointers, void execPairwiseTransform(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y, OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
void execPairwiseTransformBool(PointerPointer extraPointers, void execPairwiseTransformBool(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y, OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
@ -186,53 +170,45 @@ public interface NativeOps {
*/ */
void execReduceFloat(PointerPointer extraPointers, void execReduceFloat(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo); @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceSame(PointerPointer extraPointers, void execReduceSame(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo); @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceBool(PointerPointer extraPointers, void execReduceBool(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo); @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceLong(PointerPointer extraPointers, void execReduceLong(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo); @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
/** /**
@ -245,60 +221,56 @@ public interface NativeOps {
*/ */
void execReduceFloat2(PointerPointer extraPointers, void execReduceFloat2(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceSame2(PointerPointer extraPointers, void execReduceSame2(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceBool2(PointerPointer extraPointers, void execReduceBool2(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceLong2(PointerPointer extraPointers, void execReduceLong2(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/** /**
* @param opNum * @param opNum
@ -312,13 +284,16 @@ public interface NativeOps {
*/ */
void execReduce3(PointerPointer extraPointers, void execReduce3(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals, Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, OpaqueDataBuffer y,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo); OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
/** /**
* @param opNum * @param opNum
@ -329,13 +304,16 @@ public interface NativeOps {
* @param yShapeInfo * @param yShapeInfo
*/ */
void execReduce3Scalar(PointerPointer extraPointers, int opNum, void execReduce3Scalar(PointerPointer extraPointers, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals, Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, OpaqueDataBuffer y,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo); OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo);
/** /**
* @param opNum * @param opNum
@ -351,29 +329,37 @@ public interface NativeOps {
*/ */
void execReduce3Tad(PointerPointer extraPointers, void execReduce3Tad(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals, Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, OpaqueDataBuffer y,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, OpaqueDataBuffer result,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, @Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets); @Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets);
void execReduce3All(PointerPointer extraPointers, void execReduce3All(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals, Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo, OpaqueDataBuffer y,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, OpaqueDataBuffer result,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer xTadShape, @Cast("Nd4jLong *") LongPointer xTadShape,
@Cast("Nd4jLong *") LongPointer xOffsets, @Cast("Nd4jLong *") LongPointer xOffsets,
@Cast("Nd4jLong *") LongPointer yTadShape, @Cast("Nd4jLong *") LongPointer yTadShape,
@ -391,22 +377,28 @@ public interface NativeOps {
*/ */
void execScalar(PointerPointer extraPointers, void execScalar(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer scalar,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams); Pointer extraParams);
void execScalarBool(PointerPointer extraPointers, void execScalarBool(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer scalar,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams); Pointer extraParams);
/** /**
@ -418,11 +410,13 @@ public interface NativeOps {
*/ */
void execSummaryStatsScalar(PointerPointer extraPointers, void execSummaryStatsScalar(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, OpaqueDataBuffer z,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, @Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
boolean biasCorrected); boolean biasCorrected);
/** /**
@ -436,11 +430,13 @@ public interface NativeOps {
*/ */
void execSummaryStats(PointerPointer extraPointers, void execSummaryStats(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, OpaqueDataBuffer result,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
boolean biasCorrected); boolean biasCorrected);
/** /**
@ -455,13 +451,16 @@ public interface NativeOps {
*/ */
void execSummaryStatsTad(PointerPointer extraPointers, void execSummaryStatsTad(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer, OpaqueDataBuffer result,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
boolean biasCorrected, boolean biasCorrected,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets); @Cast("Nd4jLong *") LongPointer tadOffsets);
@ -478,42 +477,52 @@ public interface NativeOps {
*/ */
void execTransformFloat(PointerPointer extraPointers, void execTransformFloat(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
void execTransformSame(PointerPointer extraPointers, void execTransformSame(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
void execTransformStrict(PointerPointer extraPointers, void execTransformStrict(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
void execTransformBool(PointerPointer extraPointers, void execTransformBool(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
void execTransformAny(PointerPointer extraPointers, void execTransformAny(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams); Pointer extraParams);
/** /**
@ -532,31 +541,43 @@ public interface NativeOps {
*/ */
void execScalarTad(PointerPointer extraPointers, void execScalarTad(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, OpaqueDataBuffer z,
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer scalars,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, @Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ); @Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
void execScalarBoolTad(PointerPointer extraPointers, void execScalarBoolTad(PointerPointer extraPointers,
int opNum, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, OpaqueDataBuffer z,
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer scalars,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams, Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, OpaqueDataBuffer hDimension,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape, @Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, @Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ); @Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
void specialConcat(PointerPointer extraPointers, void specialConcat(PointerPointer extraPointers,
@ -675,10 +696,12 @@ public interface NativeOps {
/////////////// ///////////////
void pullRows(PointerPointer extraPointers, void pullRows(PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo, OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
long n, long n,
@Cast("Nd4jLong *") LongPointer indexes, @Cast("Nd4jLong *") LongPointer indexes,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadShapeInfo,
@ -777,28 +800,34 @@ public interface NativeOps {
void execRandom(PointerPointer extraPointers, void execRandom(PointerPointer extraPointers,
int opNum, int opNum,
Pointer state, Pointer state,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, OpaqueDataBuffer z,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments); Pointer extraArguments);
void execRandom3(PointerPointer extraPointers, void execRandom3(PointerPointer extraPointers,
int opNum, int opNum,
Pointer state, Pointer state,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeBuffer, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeBuffer, OpaqueDataBuffer y,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, @Cast("Nd4jLong *") LongPointer yShapeBuffer,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, @Cast("Nd4jLong *") LongPointer dyShapeBuffer,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments); Pointer extraArguments);
void execRandom2(PointerPointer extraPointers, void execRandom2(PointerPointer extraPointers,
int opNum, int opNum,
Pointer state, Pointer state,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer, OpaqueDataBuffer x,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer, OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments); Pointer extraArguments);
//////////////////// ////////////////////
@ -967,9 +996,11 @@ public interface NativeOps {
void tear(PointerPointer extras, void tear(PointerPointer extras,
Pointer tensor, @Cast("Nd4jLong *") LongPointer xShapeInfo, OpaqueDataBuffer tensor,
Pointer dtensor, @Cast("Nd4jLong *") LongPointer dxShapeInfo, @Cast("Nd4jLong *") LongPointer xShapeInfo,
PointerPointer targets, @Cast("Nd4jLong *") LongPointer zShapeInfo, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
PointerPointer targets,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets); @Cast("Nd4jLong *") LongPointer tadOffsets);
@ -1121,6 +1152,8 @@ public interface NativeOps {
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
@ -1162,4 +1195,27 @@ public interface NativeOps {
boolean isMinimalRequirementsMet(); boolean isMinimalRequirementsMet();
boolean isOptimalRequirementsMet(); boolean isOptimalRequirementsMet();
OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth);
OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset);
Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);
void dbExpandBuffer(OpaqueDataBuffer dataBuffer, long elements);
void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer);
void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer);
void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, Pointer primaryBuffer, long numBytes);
void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, Pointer specialBuffer, long numBytes);
void dbSyncToSpecial(OpaqueDataBuffer dataBuffer);
void dbSyncToPrimary(OpaqueDataBuffer dataBuffer);
void dbTickHostRead(OpaqueDataBuffer dataBuffer);
void dbTickHostWrite(OpaqueDataBuffer dataBuffer);
void dbTickDeviceRead(OpaqueDataBuffer dataBuffer);
void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer);
void deleteDataBuffer(OpaqueDataBuffer dataBuffer);
void dbClose(OpaqueDataBuffer dataBuffer);
int dbLocality(OpaqueDataBuffer dataBuffer);
int dbDeviceId(OpaqueDataBuffer dataBuffer);
void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId);
void dbExpand(OpaqueDataBuffer dataBuffer, long newLength);
} }

View File

@ -0,0 +1,206 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
/**
* This class is a opaque pointer to InteropDataBuffer, used for Java/C++ interop related to INDArray DataBuffer
*
* @author saudet
*/
public class OpaqueDataBuffer extends Pointer {
// TODO: make this configurable
private static final int MAX_TRIES = 3;
public OpaqueDataBuffer(Pointer p) { super(p); }
/**
* This method allocates new InteropDataBuffer and returns pointer to it
* @param numElements
* @param dataType
* @param allocateBoth
* @return
*/
public static OpaqueDataBuffer allocateDataBuffer(long numElements, @NonNull DataType dataType, boolean allocateBoth) {
OpaqueDataBuffer buffer = null;
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
// try to allocate data buffer
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if allocation failed it might be caused by casual OOM, so we'll try GC
System.gc();
} else {
// just return the buffer
return buffer;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("Allocation failed: [" + em + "]");
}
/**
* This method expands buffer, and copies content to the new buffer
*
* PLEASE NOTE: if InteropDataBuffer doesn't own actual buffers - original pointers won't be released
* @param numElements
*/
public void expand(long numElements) {
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
// try to expand the buffer
NativeOpsHolder.getInstance().getDeviceNativeOps().dbExpand(this, numElements);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if expansion failed it might be caused by casual OOM, so we'll try GC
System.gc();
} else {
// just return
return;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
}
/**
* This method creates a view out of this InteropDataBuffer
*
* @param bytesLength
* @param bytesOffset
* @return
*/
public OpaqueDataBuffer createView(long bytesLength, long bytesOffset) {
OpaqueDataBuffer buffer = null;
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateView(this, bytesLength, bytesOffset);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if view creation failed it might be caused by casual OOM, so we'll try GC
System.gc();
} else {
// just return
return buffer;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
}
/**
* This method returns pointer to linear buffer, primary one.
* @return
*/
public Pointer primaryBuffer() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(this);
}
/**
* This method returns pointer to special buffer, device one, if any.
* @return
*/
public Pointer specialBuffer() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(this);
}
/**
* This method returns deviceId of this DataBuffer
* @return
*/
public int deviceId() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbDeviceId(this);
}
/**
* This method allows to set external pointer as primary buffer.
*
* PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released
* @param ptr
* @param numElements
*/
public void setPrimaryBuffer(Pointer ptr, long numElements) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(this, ptr, numElements);
}
/**
* This method allows to set external pointer as primary buffer.
*
* PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released
* @param ptr
* @param numElements
*/
public void setSpecialBuffer(Pointer ptr, long numElements) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(this, ptr, numElements);
}
/**
* This method synchronizes device memory
*/
public void syncToSpecial() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(this);
}
/**
* This method synchronizes host memory
*/
public void syncToPrimary() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this);
}
}

View File

@ -253,6 +253,7 @@
<version>${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version}</version> <version>${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version}</version>
<classifier>${dependency.platform}</classifier> <classifier>${dependency.platform}</classifier>
</dependency> </dependency>
<!--
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>libnd4j</artifactId> <artifactId>libnd4j</artifactId>
@ -261,6 +262,7 @@
<classifier>${javacpp.platform}-cuda-${cuda.version}</classifier> <classifier>${javacpp.platform}-cuda-${cuda.version}</classifier>
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
-->
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>

View File

@ -19,6 +19,7 @@ package org.nd4j.jita.allocator.impl;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.Setter; import lombok.Setter;
import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.garbage.GarbageBufferReference; import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
@ -29,9 +30,11 @@ import org.nd4j.jita.allocator.time.providers.MillisecondsProvider;
import org.nd4j.jita.allocator.time.providers.OperativeProvider; import org.nd4j.jita.allocator.time.providers.OperativeProvider;
import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -54,8 +57,8 @@ import java.util.concurrent.locks.ReentrantLock;
public class AllocationPoint { public class AllocationPoint {
private static Logger log = LoggerFactory.getLogger(AllocationPoint.class); private static Logger log = LoggerFactory.getLogger(AllocationPoint.class);
// thread safety is guaranteed by cudaLock @Getter
private volatile PointersPair pointerInfo; private OpaqueDataBuffer ptrDataBuffer;
@Getter @Getter
@Setter @Setter
@ -104,33 +107,27 @@ public class AllocationPoint {
*/ */
private volatile int deviceId; private volatile int deviceId;
public AllocationPoint() { private long bytes;
//
public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) {
ptrDataBuffer = opaqueDataBuffer;
this.bytes = bytes;
objectId = Nd4j.getDeallocatorService().nextValue();
} }
public void acquireLock() { public void setPointers(Pointer primary, Pointer special, long numberOfElements) {
//lock.lock(); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, primary, numberOfElements);
} NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, special, numberOfElements);
public void releaseLock() {
//lock.unlock();
} }
public int getDeviceId() { public int getDeviceId() {
return deviceId; return ptrDataBuffer.deviceId();
} }
public void setDeviceId(int deviceId) { public void setDeviceId(int deviceId) {
this.deviceId = deviceId; NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId);
} }
/*
We assume 1D memory chunk allocations.
*/
@Getter
@Setter
private AllocationShape shape;
private AtomicBoolean enqueued = new AtomicBoolean(false); private AtomicBoolean enqueued = new AtomicBoolean(false);
@Getter @Getter
@ -164,7 +161,7 @@ public class AllocationPoint {
} }
public long getNumberOfBytes() { public long getNumberOfBytes() {
return shape.getNumberOfBytes(); return bytes;
} }
/* /*
@ -220,67 +217,25 @@ public class AllocationPoint {
* This method returns CUDA pointer object for this allocation. * This method returns CUDA pointer object for this allocation.
* It can be either device pointer or pinned memory pointer, or null. * It can be either device pointer or pinned memory pointer, or null.
* *
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @return * @return
*/ */
public Pointer getDevicePointer() { public Pointer getDevicePointer() {
if (pointerInfo == null) { return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(ptrDataBuffer);
log.info("pointerInfo is null");
return null;
}
return pointerInfo.getDevicePointer();
} }
/** /**
* This method returns CUDA pointer object for this allocation. * This method returns CUDA pointer object for this allocation.
* It can be either device pointer or pinned memory pointer, or null. * It can be either device pointer or pinned memory pointer, or null.
* *
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @return * @return
*/ */
public Pointer getHostPointer() { public Pointer getHostPointer() {
if (pointerInfo == null) return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(ptrDataBuffer);
return null;
return pointerInfo.getHostPointer();
}
/**
* This method sets CUDA pointer for this allocation.
* It can be either device pointer, or pinned memory pointer, or null.
*
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @param pointerInfo CUDA pointers wrapped into DevicePointerInfo
*/
public void setPointers(@NonNull PointersPair pointerInfo) {
this.pointerInfo = pointerInfo;
}
public PointersPair getPointers() {
return this.pointerInfo;
} }
public synchronized void tickDeviceRead() { public synchronized void tickDeviceRead() {
// this.deviceTicks.incrementAndGet(); NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceRead(ptrDataBuffer);
// this.timerShort.triggerEvent();
// this.timerLong.triggerEvent();
//this.deviceAccessTime.set(realTimeProvider.getCurrentTime());
this.accessDeviceRead = (timeProvider.getCurrentTime());
}
/**
* Returns time, in milliseconds, when this point was accessed on host side
*
* @return
*/
public synchronized long getHostReadTime() {
return accessHostRead;
};
public synchronized long getHostWriteTime() {
return accessHostWrite;
} }
/** /**
@ -302,7 +257,7 @@ public class AllocationPoint {
} }
public synchronized void tickHostRead() { public synchronized void tickHostRead() {
accessHostRead = (timeProvider.getCurrentTime()); NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostRead(ptrDataBuffer);
} }
/** /**
@ -310,17 +265,14 @@ public class AllocationPoint {
* *
*/ */
public synchronized void tickDeviceWrite() { public synchronized void tickDeviceWrite() {
// deviceAccessTime.set(realTimeProvider.getCurrentTime()); NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceWrite(ptrDataBuffer);
tickDeviceRead();
accessDeviceWrite = (timeProvider.getCurrentTime());
} }
/** /**
* This method sets time when this point was changed on host * This method sets time when this point was changed on host
*/ */
public synchronized void tickHostWrite() { public synchronized void tickHostWrite() {
tickHostRead(); NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostWrite(ptrDataBuffer);
accessHostWrite = (timeProvider.getCurrentTime());
} }
/** /**
@ -329,10 +281,8 @@ public class AllocationPoint {
* @return true, if data is actual, false otherwise * @return true, if data is actual, false otherwise
*/ */
public synchronized boolean isActualOnHostSide() { public synchronized boolean isActualOnHostSide() {
boolean result = accessHostWrite >= accessDeviceWrite val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
|| accessHostRead >= accessDeviceWrite; return s <= 0;
return result;
} }
/** /**
@ -341,9 +291,8 @@ public class AllocationPoint {
* @return * @return
*/ */
public synchronized boolean isActualOnDeviceSide() { public synchronized boolean isActualOnDeviceSide() {
boolean result = accessDeviceWrite >= accessHostWrite val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
|| accessDeviceRead >= accessHostWrite; return s >= 0;
return result;
} }
/** /**
@ -355,6 +304,6 @@ public class AllocationPoint {
@Override @Override
public String toString() { public String toString() {
return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + ", shape=" + shape + '}'; return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + "}";
} }
} }

View File

@ -19,12 +19,10 @@ package org.nd4j.jita.allocator.impl;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.Aggressiveness; import org.nd4j.jita.allocator.enums.Aggressiveness;
import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair; import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.time.Ring; import org.nd4j.jita.allocator.time.Ring;
@ -37,29 +35,25 @@ import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.handler.impl.CudaZeroHandler; import org.nd4j.jita.handler.impl.CudaZeroHandler;
import org.nd4j.jita.workspace.CudaWorkspace; import org.nd4j.jita.workspace.CudaWorkspace;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler; import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import java.lang.ref.ReferenceQueue;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock;
/** /**
@ -285,16 +279,10 @@ public class AtomicAllocator implements Allocator {
*/ */
@Override @Override
public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) { public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
if (buffer instanceof Utf8Buffer)
return null;
return memoryHandler.getDevicePointer(buffer, context); return memoryHandler.getDevicePointer(buffer, context);
} }
public Pointer getPointer(DataBuffer buffer) { public Pointer getPointer(DataBuffer buffer) {
if (buffer instanceof Utf8Buffer)
return null;
return memoryHandler.getDevicePointer(buffer, getDeviceContext()); return memoryHandler.getDevicePointer(buffer, getDeviceContext());
} }
@ -320,7 +308,7 @@ public class AtomicAllocator implements Allocator {
public Pointer getPointer(INDArray array, CudaContext context) { public Pointer getPointer(INDArray array, CudaContext context) {
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); // DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
if (array.isEmpty() || array.isS()) if (array.isEmpty() || array.isS())
return null; throw new UnsupportedOperationException("Pew-pew");
return memoryHandler.getDevicePointer(array.data(), context); return memoryHandler.getDevicePointer(array.data(), context);
} }
@ -372,20 +360,17 @@ public class AtomicAllocator implements Allocator {
@Override @Override
public void synchronizeHostData(DataBuffer buffer) { public void synchronizeHostData(DataBuffer buffer) {
// we don't want non-committed ops left behind // we don't want non-committed ops left behind
//Nd4j.getExecutioner().push(); Nd4j.getExecutioner().commit();
// we don't synchronize constant buffers, since we assume they are always valid on host side val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
if (buffer.isConstant() || buffer.dataType() == DataType.UTF8 || AtomicAllocator.getInstance().getAllocationPoint(buffer).getPointers().getHostPointer() == null) {
return;
}
// we actually need synchronization only in device-dependant environment. no-op otherwise // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code
if (memoryHandler.isDeviceDependant()) { NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
val point = getAllocationPoint(buffer.getTrackingPoint());
if (point == null) val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
throw new RuntimeException("AllocationPoint is NULL");
memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point); //assert oPtr.address() == cPtr.address();
} //assert buffer.address() == oPtr.address();
} }
@ -446,6 +431,7 @@ public class AtomicAllocator implements Allocator {
public AllocationPoint pickExternalBuffer(DataBuffer buffer) { public AllocationPoint pickExternalBuffer(DataBuffer buffer) {
/**
AllocationPoint point = new AllocationPoint(); AllocationPoint point = new AllocationPoint();
Long allocId = objectsTracker.getAndIncrement(); Long allocId = objectsTracker.getAndIncrement();
point.setObjectId(allocId); point.setObjectId(allocId);
@ -458,6 +444,9 @@ public class AtomicAllocator implements Allocator {
point.tickHostRead(); point.tickHostRead();
return point; return point;
*/
throw new UnsupportedOperationException("Pew-pew");
} }
/** /**
@ -469,69 +458,8 @@ public class AtomicAllocator implements Allocator {
* @param location * @param location
*/ */
@Override @Override
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
boolean initialize) { throw new UnsupportedOperationException("Pew-pew");
AllocationPoint point = new AllocationPoint();
useTracker.set(System.currentTimeMillis());
// we use these longs as tracking codes for memory tracking
Long allocId = objectsTracker.getAndIncrement();
//point.attachBuffer(buffer);
point.setObjectId(allocId);
point.setShape(requiredMemory);
/*
if (buffer instanceof CudaIntDataBuffer) {
buffer.setConstant(true);
point.setConstant(true);
}
*/
/*int numBuckets = configuration.getNumberOfGcThreads();
int bucketId = RandomUtils.nextInt(0, numBuckets);
GarbageBufferReference reference =
new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);*/
//point.attachReference(reference);
point.setDeviceId(-1);
if (buffer.isAttached()) {
long reqMem = AllocationUtils.getRequiredMemory(requiredMemory);
// workaround for init order
getMemoryHandler().getCudaContext();
point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());
val workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();
val pair = new PointersPair();
val ptrDev = workspace.alloc(reqMem, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize);
if (ptrDev != null) {
pair.setDevicePointer(ptrDev);
point.setAllocationStatus(AllocationStatus.DEVICE);
} else {
// we allocate initial host pointer only
val ptrHost = workspace.alloc(reqMem, MemoryKind.HOST, requiredMemory.getDataType(), initialize);
pair.setHostPointer(ptrHost);
pair.setDevicePointer(ptrHost);
point.setAllocationStatus(AllocationStatus.HOST);
}
point.setAttached(true);
point.setPointers(pair);
} else {
// we stay naive on PointersPair, we just don't know on this level, which pointers are set. MemoryHandler will be used for that
PointersPair pair = memoryHandler.alloc(location, point, requiredMemory, initialize);
point.setPointers(pair);
}
allocationsMap.put(allocId, point);
//point.tickHostRead();
point.tickDeviceWrite();
//point.setAllocationStatus(location);
return point;
} }
@ -619,10 +547,11 @@ public class AtomicAllocator implements Allocator {
*/ */
if (point.getBuffer() == null) { if (point.getBuffer() == null) {
purgeZeroObject(bucketId, object, point, false); purgeZeroObject(bucketId, object, point, false);
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
throw new UnsupportedOperationException("Pew-pew");
elementsDropped.incrementAndGet(); //elementsDropped.incrementAndGet();
continue; //continue;
} else { } else {
elementsSurvived.incrementAndGet(); elementsSurvived.incrementAndGet();
} }
@ -682,13 +611,14 @@ public class AtomicAllocator implements Allocator {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
// we deallocate device memory // we deallocate device memory
purgeDeviceObject(threadId, deviceId, object, point, false); purgeDeviceObject(threadId, deviceId, object, point, false);
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
// and we deallocate host memory, since object is dereferenced // and we deallocate host memory, since object is dereferenced
purgeZeroObject(point.getBucketId(), object, point, false); //purgeZeroObject(point.getBucketId(), object, point, false);
elementsDropped.incrementAndGet(); //elementsDropped.incrementAndGet();
continue; //continue;
throw new UnsupportedOperationException("Pew-pew");
} ; } ;
} else { } else {
elementsSurvived.incrementAndGet(); elementsSurvived.incrementAndGet();
@ -1014,6 +944,31 @@ public class AtomicAllocator implements Allocator {
this.memoryHandler.memcpy(dstBuffer, srcBuffer); this.memoryHandler.memcpy(dstBuffer, srcBuffer);
} }
@Override
public void tickHostWrite(DataBuffer buffer) {
getAllocationPoint(buffer).tickHostWrite();
}
@Override
public void tickHostWrite(INDArray array) {
getAllocationPoint(array.data()).tickHostWrite();
}
@Override
public void tickDeviceWrite(INDArray array) {
getAllocationPoint(array.data()).tickDeviceWrite();
}
@Override
public AllocationPoint getAllocationPoint(INDArray array) {
return getAllocationPoint(array.data());
}
@Override
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
return ((BaseCudaDataBuffer) buffer).getAllocationPoint();
}
/** /**
* This method returns deviceId for current thread * This method returns deviceId for current thread
* All values >= 0 are considered valid device IDs, all values < 0 are considered stubs. * All values >= 0 are considered valid device IDs, all values < 0 are considered stubs.
@ -1031,48 +986,6 @@ public class AtomicAllocator implements Allocator {
return new CudaPointer(getDeviceId()); return new CudaPointer(getDeviceId());
} }
@Override
public void tickHostWrite(DataBuffer buffer) {
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
point.tickHostWrite();
}
@Override
public void tickHostWrite(INDArray array) {
DataBuffer buffer =
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
tickHostWrite(buffer);
}
@Override
public void tickDeviceWrite(INDArray array) {
DataBuffer buffer =
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
point.tickDeviceWrite();
}
@Override
public AllocationPoint getAllocationPoint(INDArray array) {
if (array.isEmpty())
return null;
DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
return getAllocationPoint(buffer);
}
@Override
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
if (buffer instanceof CompressedDataBuffer) {
log.warn("Trying to get AllocationPoint from CompressedDataBuffer");
throw new RuntimeException("AP CDB");
}
return getAllocationPoint(buffer.getTrackingPoint());
}
@Override @Override
public void registerAction(CudaContext context, INDArray result, INDArray... operands) { public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
memoryHandler.registerAction(context, result, operands); memoryHandler.registerAction(context, result, operands);

View File

@ -23,46 +23,21 @@ import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
@Slf4j @Slf4j
public class CudaDeallocator implements Deallocator { public class CudaDeallocator implements Deallocator {
private AllocationPoint point; private OpaqueDataBuffer opaqueDataBuffer;
public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) { public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) {
this.point = buffer.getAllocationPoint(); opaqueDataBuffer = buffer.getOpaqueDataBuffer();
if (this.point == null)
throw new RuntimeException();
} }
@Override @Override
public void deallocate() { public void deallocate() {
log.trace("Deallocating CUDA memory"); log.trace("Deallocating CUDA memory");
// skipping any allocation that is coming from workspace NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer);
if (point.isAttached() || point.isReleased()) {
// TODO: remove allocation point as well?
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
return;
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
AtomicAllocator.getInstance().allocationsMap().remove(point.getObjectId());
return;
}
//log.info("Purging {} bytes...", AllocationUtils.getRequiredMemory(point.getShape()));
if (point.getAllocationStatus() == AllocationStatus.HOST) {
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
} else if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
AtomicAllocator.getInstance().purgeDeviceObject(0L, point.getDeviceId(), point.getObjectId(), point, false);
// and we deallocate host memory, since object is dereferenced
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
}
} }
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.jita.allocator.pointers.cuda; package org.nd4j.jita.allocator.pointers.cuda;
import lombok.NonNull; import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JException;
@ -37,8 +38,9 @@ public class cudaStream_t extends CudaPointer {
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
int res = nativeOps.streamSynchronize(this); int res = nativeOps.streamSynchronize(this);
if (nativeOps.lastErrorCode() != 0) val ec = nativeOps.lastErrorCode();
throw new RuntimeException(nativeOps.lastErrorMessage()); if (ec != 0)
throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec);
return res; return res;
} }

View File

@ -129,7 +129,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer); AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape()); long requiredMemoryBytes = point.getNumberOfBytes();
val originalBytes = requiredMemoryBytes; val originalBytes = requiredMemoryBytes;
requiredMemoryBytes += 8 - (requiredMemoryBytes % 8); requiredMemoryBytes += 8 - (requiredMemoryBytes % 8);
@ -147,13 +147,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
if (point.getAllocationStatus() == AllocationStatus.HOST if (point.getAllocationStatus() == AllocationStatus.HOST
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
false); throw new UnsupportedOperationException("Pew-pew");
} }
val profD = PerformanceTracker.getInstance().helperStartTransaction(); val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
throw new ND4JIllegalStateException("memcpyAsync failed"); throw new ND4JIllegalStateException("memcpyAsync failed");
} }
flowController.commitTransfer(context.getSpecialStream()); flowController.commitTransfer(context.getSpecialStream());
@ -176,14 +176,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
if (currentOffset >= MAX_CONSTANT_LENGTH) { if (currentOffset >= MAX_CONSTANT_LENGTH) {
if (point.getAllocationStatus() == AllocationStatus.HOST if (point.getAllocationStatus() == AllocationStatus.HOST
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
false); throw new UnsupportedOperationException("Pew-pew");
} }
val profD = PerformanceTracker.getInstance().helperStartTransaction(); val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
originalBytes, 1, context.getSpecialStream()) == 0) {
throw new ND4JIllegalStateException("memcpyAsync failed"); throw new ND4JIllegalStateException("memcpyAsync failed");
} }
flowController.commitTransfer(context.getSpecialStream()); flowController.commitTransfer(context.getSpecialStream());
@ -202,8 +201,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), originalBytes, 1, NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream());
context.getSpecialStream());
flowController.commitTransfer(context.getSpecialStream()); flowController.commitTransfer(context.getSpecialStream());
long cAddr = deviceAddresses.get(deviceId).address() + currentOffset; long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;
@ -212,7 +210,10 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
// logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr); // logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);
point.setAllocationStatus(AllocationStatus.CONSTANT); point.setAllocationStatus(AllocationStatus.CONSTANT);
point.getPointers().setDevicePointer(new CudaPointer(cAddr)); //point.setDevicePointer(new CudaPointer(cAddr));
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
point.setConstant(true); point.setConstant(true);
point.tickDeviceWrite(); point.tickDeviceWrite();
point.setDeviceId(deviceId); point.setDeviceId(deviceId);

View File

@ -32,6 +32,7 @@ import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController; import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
@ -70,53 +71,12 @@ public class SynchronousFlowController implements FlowController {
*/ */
@Override @Override
public void synchronizeToHost(AllocationPoint point) { public void synchronizeToHost(AllocationPoint point) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer());
if (!point.isActualOnHostSide()) {
val context = allocator.getDeviceContext();
if (!point.isConstant())
waitTillFinished(point);
// if this piece of memory is device-dependant, we'll also issue copyback once
if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
val bytes = AllocationUtils.getRequiredMemory(point.getShape());
if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0)
throw new IllegalStateException("synchronizeToHost memcpyAsync failed: " + point.getShape());
commitTransfer(context.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.DEVICE_TO_HOST);
}
// updating host read timer
point.tickHostRead();
}
} }
@Override @Override
public void synchronizeToDevice(@NonNull AllocationPoint point) { public void synchronizeToDevice(@NonNull AllocationPoint point) {
if (point.isConstant()) NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer());
return;
if (!point.isActualOnDeviceSide()) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
val context = allocator.getDeviceContext();
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(),
AllocationUtils.getRequiredMemory(point.getShape()),
CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
commitTransfer(context.getSpecialStream());
point.tickDeviceRead();
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
}
}
} }
@Override @Override
@ -147,7 +107,6 @@ public class SynchronousFlowController implements FlowController {
val pointData = allocator.getAllocationPoint(operand); val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
pointData.acquireLock();
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) { if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) {
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
@ -172,15 +131,12 @@ public class SynchronousFlowController implements FlowController {
val cId = allocator.getDeviceId(); val cId = allocator.getDeviceId();
if (result != null && !result.isEmpty() && !result.isS()) { if (result != null && !result.isEmpty()) {
Nd4j.getCompressor().autoDecompress(result); Nd4j.getCompressor().autoDecompress(result);
prepareDelayedMemory(result); prepareDelayedMemory(result);
val pointData = allocator.getAllocationPoint(result); val pointData = allocator.getAllocationPoint(result);
val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer()); val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
pointData.acquireLock();
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data() DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data()
: result.data().originalDataBuffer(); : result.data().originalDataBuffer();
@ -206,8 +162,7 @@ public class SynchronousFlowController implements FlowController {
val pointData = allocator.getAllocationPoint(operand); val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE);
pointData.acquireLock();
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
@ -240,14 +195,12 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(result.getLastWriteEvent()); eventsProvider.storeEvent(result.getLastWriteEvent());
result.setLastWriteEvent(eventsProvider.getEvent()); result.setLastWriteEvent(eventsProvider.getEvent());
result.getLastWriteEvent().register(context.getOldStream()); result.getLastWriteEvent().register(context.getOldStream());
result.releaseLock();
for (AllocationPoint operand : operands) { for (AllocationPoint operand : operands) {
eventsProvider.storeEvent(operand.getLastReadEvent()); eventsProvider.storeEvent(operand.getLastReadEvent());
operand.setLastReadEvent(eventsProvider.getEvent()); operand.setLastReadEvent(eventsProvider.getEvent());
operand.getLastReadEvent().register(context.getOldStream()); operand.getLastReadEvent().register(context.getOldStream());
operand.releaseLock();
} }
// context.syncOldStream(); // context.syncOldStream();
} }
@ -263,7 +216,6 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(pointOperand.getLastWriteEvent()); eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
pointOperand.setLastWriteEvent(eventsProvider.getEvent()); pointOperand.setLastWriteEvent(eventsProvider.getEvent());
pointOperand.getLastWriteEvent().register(context.getOldStream()); pointOperand.getLastWriteEvent().register(context.getOldStream());
pointOperand.releaseLock();
} }
} }
@ -276,14 +228,12 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(point.getLastWriteEvent()); eventsProvider.storeEvent(point.getLastWriteEvent());
point.setLastWriteEvent(eventsProvider.getEvent()); point.setLastWriteEvent(eventsProvider.getEvent());
point.getLastWriteEvent().register(context.getOldStream()); point.getLastWriteEvent().register(context.getOldStream());
point.releaseLock();
for (INDArray operand : operands) { for (INDArray operand : operands) {
if (operand == null || operand.isEmpty()) if (operand == null || operand.isEmpty())
continue; continue;
val pointOperand = allocator.getAllocationPoint(operand); val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.releaseLock();
eventsProvider.storeEvent(pointOperand.getLastReadEvent()); eventsProvider.storeEvent(pointOperand.getLastReadEvent());
pointOperand.setLastReadEvent(eventsProvider.getEvent()); pointOperand.setLastReadEvent(eventsProvider.getEvent());
pointOperand.getLastReadEvent().register(context.getOldStream()); pointOperand.getLastReadEvent().register(context.getOldStream());
@ -295,7 +245,6 @@ public class SynchronousFlowController implements FlowController {
val context = allocator.getDeviceContext(); val context = allocator.getDeviceContext();
if (result != null) { if (result != null) {
result.acquireLock();
result.setCurrentContext(context); result.setCurrentContext(context);
} }
@ -303,7 +252,6 @@ public class SynchronousFlowController implements FlowController {
if (operand == null) if (operand == null)
continue; continue;
operand.acquireLock();
operand.setCurrentContext(context); operand.setCurrentContext(context);
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.jita.handler.impl; package org.nd4j.jita.handler.impl;
import lombok.var;
import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table; import org.nd4j.shade.guava.collect.Table;
@ -44,9 +45,6 @@ import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.GridFlowController; import org.nd4j.jita.flow.impl.GridFlowController;
import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider; import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.jita.memory.impl.CudaDirectProvider;
import org.nd4j.jita.memory.impl.CudaFullCachingProvider;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -99,9 +97,6 @@ public class CudaZeroHandler implements MemoryHandler {
private final AtomicBoolean wasInitialised = new AtomicBoolean(false); private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
@Getter
private final MemoryProvider memoryProvider;
private final FlowController flowController; private final FlowController flowController;
private final AllocationStatus INITIAL_LOCATION; private final AllocationStatus INITIAL_LOCATION;
@ -148,20 +143,6 @@ public class CudaZeroHandler implements MemoryHandler {
throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]"); throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
} }
switch (configuration.getAllocationModel()) {
case CACHE_ALL:
this.memoryProvider = new CudaFullCachingProvider();
break;
case CACHE_HOST:
this.memoryProvider = new CudaCachingZeroProvider();
break;
case DIRECT:
this.memoryProvider = new CudaDirectProvider();
break;
default:
throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]");
}
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
for (int i = 0; i < numDevices; i++) { for (int i = 0; i < numDevices; i++) {
deviceAllocations.add(new ConcurrentHashMap<Long, Long>()); deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
@ -191,7 +172,7 @@ public class CudaZeroHandler implements MemoryHandler {
int numBuckets = configuration.getNumberOfGcThreads(); int numBuckets = configuration.getNumberOfGcThreads();
long bucketId = RandomUtils.nextInt(0, numBuckets); long bucketId = RandomUtils.nextInt(0, numBuckets);
long reqMemory = AllocationUtils.getRequiredMemory(point.getShape()); long reqMemory = point.getNumberOfBytes();
zeroUseCounter.addAndGet(reqMemory); zeroUseCounter.addAndGet(reqMemory);
@ -221,130 +202,7 @@ public class CudaZeroHandler implements MemoryHandler {
public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape, public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape,
boolean initialize) { boolean initialize) {
long reqMemory = AllocationUtils.getRequiredMemory(shape); throw new UnsupportedOperationException();
val context = getCudaContext();
switch (targetMode) {
case HOST: {
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
while (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
val before = MemoryTracker.getInstance().getActiveHostAmount();
memoryProvider.purgeCache();
Nd4j.getMemoryManager().invokeGc();
val after = MemoryTracker.getInstance().getActiveHostAmount();
log.debug("[HOST] before: {}; after: {};", before, after);
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
try {
log.warn("No available [HOST] memory, sleeping for a while... Consider increasing -Xmx next time.");
log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: [" + zeroAllocations.get(0) + "]");
memoryProvider.purgeCache();
Nd4j.getMemoryManager().invokeGc();
Thread.sleep(1000);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
PointersPair pair = memoryProvider.malloc(shape, point, targetMode);
if (initialize) {
org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory);
point.tickHostWrite();
}
pickupHostAllocation(point);
return pair;
}
case DEVICE: {
int deviceId = getDeviceId();
PointersPair returnPair = new PointersPair();
PointersPair tmpPair = new PointersPair();
if (point.getPointers() == null)
point.setPointers(tmpPair);
if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId, reqMemory)) {
point.setDeviceId(deviceId);
val pair = memoryProvider.malloc(shape, point, targetMode);
if (pair != null) {
returnPair.setDevicePointer(pair.getDevicePointer());
point.setAllocationStatus(AllocationStatus.DEVICE);
if (point.getPointers() == null)
throw new RuntimeException("PointersPair can't be null");
point.getPointers().setDevicePointer(pair.getDevicePointer());
deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId());
val p = point.getBucketId();
if (p != null) {
val m = zeroAllocations.get(point.getBucketId());
// m can be null, if that's point from workspace - just no bucketId for it
if (m != null)
m.remove(point.getObjectId());
}
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory);
if (!initialize) {
point.tickDeviceWrite();
} else {
nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0, context.getSpecialStream());
context.getSpecialStream().synchronize();
point.tickDeviceWrite();
}
} else {
log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]; Approximate free bytes: {}; Real free bytes: {}", deviceId, reqMemory, MemoryTracker.getInstance().getApproximateFreeMemory(deviceId), MemoryTracker.getInstance().getPreciseFreeMemory(deviceId));
log.info("Total allocated dev_0: {}", MemoryTracker.getInstance().getActiveMemory(0));
log.info("Cached dev_0: {}", MemoryTracker.getInstance().getCachedAmount(0));
log.info("Allocated dev_0: {}", MemoryTracker.getInstance().getAllocatedAmount(0));
log.info("Workspace dev_0: {}", MemoryTracker.getInstance().getWorkspaceAllocatedAmount(0));
//log.info("Total allocated dev_1: {}", MemoryTracker.getInstance().getActiveMemory(1));
// if device memory allocation failed (aka returned NULL), keep using host memory instead
returnPair.setDevicePointer(tmpPair.getHostPointer());
point.setAllocationStatus(AllocationStatus.HOST);
Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(100);
} catch (Exception e) {
}
}
} else {
log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]",
deviceId);
Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(100);
} catch (InterruptedException e) {
//
}
}
return returnPair;
}
default:
throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]");
}
} }
/** /**
@ -356,7 +214,7 @@ public class CudaZeroHandler implements MemoryHandler {
*/ */
@Override @Override
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) { public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
return memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory); return true;
} }
/** /**
@ -371,47 +229,7 @@ public class CudaZeroHandler implements MemoryHandler {
@Override @Override
public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point, public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point,
AllocationShape shape, CudaContext context) { AllocationShape shape, CudaContext context) {
//log.info("RELOCATE CALLED: [" +currentStatus+ "] -> ["+targetStatus+"]");
if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) {
// DEVICE -> HOST
DataBuffer targetBuffer = point.getBuffer();
if (targetBuffer == null)
throw new IllegalStateException("Target buffer is NULL!");
Pointer devicePointer = new CudaPointer(point.getPointers().getDevicePointer().address());
} else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) {
// HOST -> DEVICE
// TODO: this probably should be removed
if (point.isConstant()) {
//log.info("Skipping relocation for constant");
return;
}
if (point.getPointers().getDevicePointer() == null) {
throw new IllegalStateException("devicePointer is NULL!");
}
val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice,
context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address()
+ "] -> [" + point.getDevicePointer().address() + "]");
flowController.commitTransfer(context.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
//context.syncOldStream();
} else
throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus
+ "] -> [" + targetStatus + "]");
} }
/** /**
@ -440,11 +258,6 @@ public class CudaZeroHandler implements MemoryHandler {
@Override @Override
@Deprecated @Deprecated
public void copyforward(AllocationPoint point, AllocationShape shape) { public void copyforward(AllocationPoint point, AllocationShape shape) {
/*
Technically that's just a case for relocate, with source as HOST and target point.getAllocationStatus()
*/
log.info("copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape());
//relocate(AllocationStatus.HOST, point.getAllocationStatus(), point, shape);
throw new UnsupportedOperationException("Deprecated call"); throw new UnsupportedOperationException("Deprecated call");
} }
@ -467,15 +280,7 @@ public class CudaZeroHandler implements MemoryHandler {
*/ */
@Override @Override
public void free(AllocationPoint point, AllocationStatus target) { public void free(AllocationPoint point, AllocationStatus target) {
//if (point.getAllocationStatus() == AllocationStatus.DEVICE)
//deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId());
//zeroAllocations.get(point.getBucketId()).remove(point.getObjectId());
if (point.getAllocationStatus() == AllocationStatus.DEVICE)
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(),
AllocationUtils.getRequiredMemory(point.getShape()));
memoryProvider.free(point);
} }
/** /**
@ -525,7 +330,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext tContext = null; CudaContext tContext = null;
if (dstBuffer.isConstant()) { if (dstBuffer.isConstant()) {
org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L); org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L);
org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length); org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length);
val profD = PerformanceTracker.getInstance().helperStartTransaction(); val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -534,14 +339,34 @@ public class CudaZeroHandler implements MemoryHandler {
point.tickHostRead(); point.tickHostRead();
} else { } else {
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
if (tContext == null)
tContext = flowController.prepareAction(point);
var prof = PerformanceTracker.getInstance().helperStartTransaction();
flowController.commitTransfer(tContext.getSpecialStream());
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
flowController.registerAction(tContext, point);
point.tickDeviceWrite();
// we optionally copy to host memory // we optionally copy to host memory
if (point.getPointers().getHostPointer() != null) { if (point.getHostPointer() != null) {
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
CudaContext context = flowController.prepareAction(point); CudaContext context = flowController.prepareAction(point);
tContext = context; tContext = context;
val prof = PerformanceTracker.getInstance().helperStartTransaction(); prof = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0) if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]"); throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]");
@ -552,28 +377,10 @@ public class CudaZeroHandler implements MemoryHandler {
if (point.getAllocationStatus() == AllocationStatus.HOST) if (point.getAllocationStatus() == AllocationStatus.HOST)
flowController.registerAction(context, point); flowController.registerAction(context, point);
point.tickHostRead();
} }
} }
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
if (tContext == null)
tContext = flowController.prepareAction(point);
val prof = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE);
flowController.registerAction(tContext, point);
point.tickDeviceWrite();
}
} }
@Override @Override
@ -581,7 +388,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext context) { CudaContext context) {
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset); Pointer dP = new CudaPointer((point.getDevicePointer().address()) + dstOffset);
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memcpyAsync failed"); throw new ND4JIllegalStateException("memcpyAsync failed");
@ -604,7 +411,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext context = getCudaContext(); CudaContext context = getCudaContext();
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
val profH = PerformanceTracker.getInstance().helperStartTransaction(); val profH = PerformanceTracker.getInstance().helperStartTransaction();
@ -614,7 +421,7 @@ public class CudaZeroHandler implements MemoryHandler {
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST); PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);
if (point.getAllocationStatus() == AllocationStatus.DEVICE) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset); Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
val profD = PerformanceTracker.getInstance().helperStartTransaction(); val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -717,23 +524,22 @@ public class CudaZeroHandler implements MemoryHandler {
@Override @Override
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) { public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
// TODO: It would be awesome to get rid of typecasting here // TODO: It would be awesome to get rid of typecasting here
//getCudaContext().syncOldStream();
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
// if that's device state, we probably might want to update device memory state // if that's device state, we probably might want to update device memory state
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) { if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
if (!dstPoint.isActualOnDeviceSide()) { if (!dstPoint.isActualOnDeviceSide()) {
// log.info("Relocating to GPU"); //relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context); throw new UnsupportedOperationException("Pew-pew");
} }
} }
// we update memory use counter, to announce that it's somehow used on device if (dstPoint.getDevicePointer() == null)
dstPoint.tickDeviceRead(); return null;
// return pointer with offset if needed. length is specified for constructor compatibility purposes
val p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(), // return pointer. length is specified for constructor compatibility purposes. Offset is accounted at C++ side
(buffer.offset() * buffer.getElementSize())); val p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0);
if (OpProfiler.getInstance().getConfig().isCheckLocality()) if (OpProfiler.getInstance().getConfig().isCheckLocality())
NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1); NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1);
@ -749,10 +555,17 @@ public class CudaZeroHandler implements MemoryHandler {
case SHORT: case SHORT:
case UINT16: case UINT16:
case HALF: case HALF:
case BFLOAT16:
return p.asShortPointer(); return p.asShortPointer();
case UINT64: case UINT64:
case LONG: case LONG:
return p.asLongPointer(); return p.asLongPointer();
case UTF8:
case UBYTE:
case BYTE:
return p.asBytePointer();
case BOOL:
return p.asBooleanPointer();
default: default:
return p; return p;
} }
@ -769,17 +582,14 @@ public class CudaZeroHandler implements MemoryHandler {
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
// return pointer with offset if needed. length is specified for constructor compatibility purposes // return pointer with offset if needed. length is specified for constructor compatibility purposes
if (dstPoint.getPointers().getHostPointer() == null) { if (dstPoint.getHostPointer() == null) {
return null; return null;
} }
//dstPoint.tickHostWrite();
//dstPoint.tickHostRead();
//log.info("Requesting host pointer for {}", buffer);
//getCudaContext().syncOldStream();
synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint); synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(), CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0);
(buffer.offset() * buffer.getElementSize()));
switch (buffer.dataType()) { switch (buffer.dataType()) {
case DOUBLE: case DOUBLE:
return p.asDoublePointer(); return p.asDoublePointer();
@ -805,6 +615,9 @@ public class CudaZeroHandler implements MemoryHandler {
public synchronized void relocateObject(DataBuffer buffer) { public synchronized void relocateObject(DataBuffer buffer) {
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
// we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT) // we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE) if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
return; return;
@ -838,14 +651,14 @@ public class CudaZeroHandler implements MemoryHandler {
// if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
// host part is optional // host part is optional
if (dstPoint.getHostPointer() != null) { if (dstPoint.getHostPointer() != null) {
val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false); //val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setHostPointer(pairH.getHostPointer()); //dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
} }
val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); //val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer()); //dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
//log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address()); ////log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
CudaContext context = getCudaContext(); CudaContext context = getCudaContext();
@ -876,10 +689,10 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getMemoryManager().memcpy(nBuffer, buffer); Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer()); //dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
if (dstPoint.getHostPointer() != null) { if (dstPoint.getHostPointer() != null) {
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer()); // dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
} }
dstPoint.setDeviceId(deviceId); dstPoint.setDeviceId(deviceId);
@ -908,11 +721,10 @@ public class CudaZeroHandler implements MemoryHandler {
context.syncSpecialStream(); context.syncSpecialStream();
} }
memoryProvider.free(dstPoint); //deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
// we replace original device pointer with new one // we replace original device pointer with new one
alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); //alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
val profD = PerformanceTracker.getInstance().helperStartTransaction(); val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -940,6 +752,9 @@ public class CudaZeroHandler implements MemoryHandler {
public boolean promoteObject(DataBuffer buffer) { public boolean promoteObject(DataBuffer buffer) {
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
if (dstPoint.getAllocationStatus() != AllocationStatus.HOST) if (dstPoint.getAllocationStatus() != AllocationStatus.HOST)
return false; return false;
@ -952,20 +767,19 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getConstantHandler().moveToConstantSpace(buffer); Nd4j.getConstantHandler().moveToConstantSpace(buffer);
} else { } else {
PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE); PointersPair pair = null; //memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
if (pair != null) { if (pair != null) {
Integer deviceId = getDeviceId(); Integer deviceId = getDeviceId();
// log.info("Promoting object to device: [{}]", deviceId); // log.info("Promoting object to device: [{}]", deviceId);
dstPoint.getPointers().setDevicePointer(pair.getDevicePointer()); //dstPoint.setDevicePointer(pair.getDevicePointer());
dstPoint.setAllocationStatus(AllocationStatus.DEVICE); dstPoint.setAllocationStatus(AllocationStatus.DEVICE);
deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId()); deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId());
zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId()); zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId());
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, //deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, AllocationUtils.getRequiredMemory(dstPoint.getShape()));
AllocationUtils.getRequiredMemory(dstPoint.getShape()));
dstPoint.tickHostWrite(); dstPoint.tickHostWrite();
@ -1103,7 +917,7 @@ public class CudaZeroHandler implements MemoryHandler {
if (deviceAllocations.get(deviceId).containsKey(objectId)) if (deviceAllocations.get(deviceId).containsKey(objectId))
throw new IllegalStateException("Can't happen ever"); throw new IllegalStateException("Can't happen ever");
deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape())); //deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
point.setAllocationStatus(AllocationStatus.HOST); point.setAllocationStatus(AllocationStatus.HOST);
@ -1119,6 +933,9 @@ public class CudaZeroHandler implements MemoryHandler {
*/ */
@Override @Override
public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) { public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
forget(point, AllocationStatus.HOST); forget(point, AllocationStatus.HOST);
flowController.waitTillReleased(point); flowController.waitTillReleased(point);
@ -1127,8 +944,8 @@ public class CudaZeroHandler implements MemoryHandler {
if (point.getHostPointer() != null) { if (point.getHostPointer() != null) {
free(point, AllocationStatus.HOST); free(point, AllocationStatus.HOST);
long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1; //long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1;
zeroUseCounter.addAndGet(reqMem); //zeroUseCounter.addAndGet(reqMem);
} }
point.setAllocationStatus(AllocationStatus.DEALLOCATED); point.setAllocationStatus(AllocationStatus.DEALLOCATED);
@ -1252,4 +1069,9 @@ public class CudaZeroHandler implements MemoryHandler {
public FlowController getFlowController() { public FlowController getFlowController() {
return flowController; return flowController;
} }
@Override
public MemoryProvider getMemoryProvider() {
return null;
}
} }

View File

@ -147,7 +147,7 @@ public class CudaMemoryManager extends BasicMemoryManager {
// Nd4j.getShapeInfoProvider().purgeCache(); // Nd4j.getShapeInfoProvider().purgeCache();
// purge memory cache // purge memory cache
AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache(); //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
} }

View File

@ -1,303 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.MemoryProvider;
import org.slf4j.Logger;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.jita.allocator.impl.MemoryTracker;
/**
* This is MemoryProvider implementation, that adds cache for memory reuse purposes. Only host memory is cached for future reuse.
*
* If some memory chunk gets released via allocator, it'll be probably saved for future reused within same JVM process.
*
* @author raver119@gmail.com
*/
public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider {
private static Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);
protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<>();
protected final AtomicLong cacheZeroHit = new AtomicLong(0);
protected final AtomicLong cacheZeroMiss = new AtomicLong(0);
protected final AtomicLong cacheDeviceHit = new AtomicLong(0);
protected final AtomicLong cacheDeviceMiss = new AtomicLong(0);
private final AtomicLong allocRequests = new AtomicLong(0);
protected final AtomicLong zeroCachedAmount = new AtomicLong(0);
protected List<AtomicLong> deviceCachedAmount = new ArrayList<>();
protected final Semaphore singleLock = new Semaphore(1);
// we don't cache allocations greater then this value
//protected final long MAX_SINGLE_ALLOCATION = configuration.getMaximumHostCacheableLength();
// maximum cached size of memory
//protected final long MAX_CACHED_MEMORY = configuration.getMaximumHostCache();
// memory chunks below this threshold will be guaranteed regardless of number of cache entries
// that especially covers all possible variations of shapeInfoDataBuffers in all possible cases
protected final long FORCED_CACHE_THRESHOLD = 96;
// number of preallocation entries for each yet-unknown shape
//protected final int PREALLOCATION_LIMIT = configuration.getPreallocationCalls();
public CudaCachingZeroProvider() {
}
/**
* This method provides PointersPair to memory chunk specified by AllocationShape
*
* PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape.
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
long reqMemory = AllocationUtils.getRequiredMemory(shape);
if (location == AllocationStatus.HOST && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) {
val cache = zeroCache.get(shape);
if (cache != null) {
val pointer = cache.poll();
if (pointer != null) {
cacheZeroHit.incrementAndGet();
// since this memory chunk is going to be used now, remove it's amount from
zeroCachedAmount.addAndGet(-1 * reqMemory);
val pair = new PointersPair();
pair.setDevicePointer(new CudaPointer(pointer.address()));
pair.setHostPointer(new CudaPointer(pointer.address()));
point.setAllocationStatus(AllocationStatus.HOST);
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMemory);
MemoryTracker.getInstance().decrementCachedHostAmount(reqMemory);
return pair;
}
}
cacheZeroMiss.incrementAndGet();
if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation() && zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10
&& reqMemory < 16 * 1024 * 1024L) {
val preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls());
preallocator.start();
}
cacheZeroMiss.incrementAndGet();
return super.malloc(shape, point, location);
}
return super.malloc(shape, point, location);
}
protected void ensureCacheHolder(AllocationShape shape) {
if (!zeroCache.containsKey(shape)) {
try {
singleLock.acquire();
if (!zeroCache.containsKey(shape)) {
zeroCache.put(shape, new CacheHolder(shape, zeroCachedAmount));
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
singleLock.release();
}
}
}
/**
* This method frees specific chunk of memory, described by AllocationPoint passed in.
*
* PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse.
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
super.free(point);
} else {
// if this point has no allocated chunk - step over it
if (point.getHostPointer() == null)
return;
AllocationShape shape = point.getShape();
long reqMemory = AllocationUtils.getRequiredMemory(shape);
// we don't cache too big objects
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength() || zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) {
super.free(point);
return;
}
ensureCacheHolder(shape);
/*
Now we should decide if this object can be cached or not
*/
CacheHolder cache = zeroCache.get(shape);
// memory chunks < threshold will be cached no matter what
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
Pointer.memset(point.getHostPointer(), 0, reqMemory);
cache.put(new CudaPointer(point.getHostPointer().address()));
} else {
long cacheEntries = cache.size();
long cacheHeight = zeroCache.size();
// total memory allocated within this bucket
long cacheDepth = cacheEntries * reqMemory;
Pointer.memset(point.getHostPointer(), 0, reqMemory);
cache.put(new CudaPointer(point.getHostPointer().address()));
}
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMemory);
MemoryTracker.getInstance().incrementCachedHostAmount(reqMemory);
}
}
private float getZeroCacheHitRatio() {
long totalHits = cacheZeroHit.get() + cacheZeroMiss.get();
float cacheRatio = cacheZeroHit.get() * 100 / (float) totalHits;
return cacheRatio;
}
private float getDeviceCacheHitRatio() {
long totalHits = cacheDeviceHit.get() + cacheDeviceMiss.get();
float cacheRatio = cacheDeviceHit.get() * 100 / (float) totalHits;
return cacheRatio;
}
@Deprecated
public void printCacheStats() {
log.debug("Cached host amount: " + zeroCachedAmount.get());
log.debug("Cached device amount: " + deviceCachedAmount.get(0).get());
log.debug("Total shapes in cache: " + zeroCache.size());
log.debug("Current host hit ratio: " + getZeroCacheHitRatio());
log.debug("Current device hit ratio: " + getDeviceCacheHitRatio());
}
protected class CacheHolder {
private Queue<Pointer> queue = new ConcurrentLinkedQueue<>();
private volatile int counter = 0;
private long reqMem = 0;
private final AtomicLong allocCounter;
public CacheHolder(AllocationShape shape, AtomicLong counter) {
this.reqMem = AllocationUtils.getRequiredMemory(shape);
this.allocCounter = counter;
}
public synchronized int size() {
return counter;
}
public synchronized Pointer poll() {
val pointer = queue.poll();
if (pointer != null)
counter--;
return pointer;
}
public synchronized void put(Pointer pointer) {
allocCounter.addAndGet(reqMem);
counter++;
queue.add(pointer);
}
}
protected class CachePreallocator extends Thread implements Runnable {
private AllocationShape shape;
private AllocationStatus location;
private int target;
public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) {
this.shape = shape;
this.target = numberOfEntries;
this.location = location;
}
@Override
public void run() {
ensureCacheHolder(shape);
for (int i = 0; i < target; i++) {
val point = new AllocationPoint();
val pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location);
if (this.location == AllocationStatus.HOST) {
Pointer pointer = new CudaPointer(pair.getHostPointer().address());
CudaCachingZeroProvider.this.zeroCache.get(shape).put(pointer);
}
}
}
}
@Override
public void purgeCache() {
for (AllocationShape shape : zeroCache.keySet()) {
Pointer ptr = null;
while ((ptr = zeroCache.get(shape).poll()) != null) {
freeHost(ptr);
MemoryTracker.getInstance().decrementCachedHostAmount(shape.getNumberOfBytes());
}
}
zeroCachedAmount.set(0);
}
}

View File

@ -1,239 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import lombok.var;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* @author raver119@gmail.com
*/
public class CudaDirectProvider implements MemoryProvider {
protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L;
private static Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);
protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
protected volatile ConcurrentHashMap<Long, Integer> validator = new ConcurrentHashMap<>();
private AtomicLong emergencyCounter = new AtomicLong(0);
/**
* This method provides PointersPair to memory chunk specified by AllocationShape
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
//log.info("shape onCreate: {}, target: {}", shape, location);
switch (location) {
case HOST: {
long reqMem = AllocationUtils.getRequiredMemory(shape);
// FIXME: this is WRONG, and directly leads to memleak
if (reqMem < 1)
reqMem = 1;
val pointer = nativeOps.mallocHost(reqMem, 0);
if (pointer == null)
throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: "
+ Thread.currentThread().getId());
// log.info("Host allocation, Thread id: {}, ReqMem: {}, Pointer: {}", Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null);
val hostPointer = new CudaPointer(pointer);
val devicePointerInfo = new PointersPair();
if (point.getPointers().getDevicePointer() == null) {
point.setAllocationStatus(AllocationStatus.HOST);
devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem));
} else
devicePointerInfo.setDevicePointer(point.getDevicePointer());
devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem));
point.setPointers(devicePointerInfo);
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem);
return devicePointerInfo;
}
case DEVICE: {
// cudaMalloc call
val deviceId = AtomicAllocator.getInstance().getDeviceId();
long reqMem = AllocationUtils.getRequiredMemory(shape);
// FIXME: this is WRONG, and directly leads to memleak
if (reqMem < 1)
reqMem = 1;
AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem);
var pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
if (pointer == null) {
// try to purge stuff if we're low on memory
purgeCache(deviceId);
// call for gc
Nd4j.getMemoryManager().invokeGc();
pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
if (pointer == null)
return null;
}
val devicePointer = new CudaPointer(pointer);
var devicePointerInfo = point.getPointers();
if (devicePointerInfo == null)
devicePointerInfo = new PointersPair();
devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem));
point.setAllocationStatus(AllocationStatus.DEVICE);
point.setDeviceId(deviceId);
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem);
return devicePointerInfo;
}
default:
throw new IllegalStateException("Unsupported location for malloc: [" + location + "]");
}
}
/**
* This method frees specific chunk of memory, described by AllocationPoint passed in
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
switch (point.getAllocationStatus()) {
case HOST: {
// cudaFreeHost call here
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
long result = nativeOps.freeHost(point.getPointers().getHostPointer());
if (result == 0) {
throw new RuntimeException("Can't deallocate [HOST] memory...");
}
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem);
}
break;
case DEVICE: {
if (point.isConstant())
return;
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem);
val pointers = point.getPointers();
long result = nativeOps.freeDevice(pointers.getDevicePointer(), 0);
if (result == 0)
throw new RuntimeException("Can't deallocate [DEVICE] memory...");
MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem);
}
break;
default:
throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]");
}
}
/**
* This method checks specified device for specified amount of memory
*
* @param deviceId
* @param requiredMemory
* @return
*/
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
/*
long[] totalMem = new long[1];
long[] freeMem = new long[1];
JCuda.cudaMemGetInfo(freeMem, totalMem);
long free = freeMem[0];
long total = totalMem[0];
long used = total - free;
/*
We don't want to allocate memory if it's too close to the end of available ram.
*/
//if (configuration != null && used > total * configuration.getMaxDeviceMemoryUsed()) return false;
/*
if (free + requiredMemory < total * 0.85)
return true;
else return false;
*/
long freeMem = nativeOps.getDeviceFreeMemory(-1);
if (freeMem - requiredMemory < DEVICE_RESERVED_SPACE)
return false;
else
return true;
}
protected void freeHost(Pointer pointer) {
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
nativeOps.freeHost(pointer);
}
protected void freeDevice(Pointer pointer, int deviceId) {
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
nativeOps.freeDevice(pointer, 0);
}
protected void purgeCache(int deviceId) {
//
}
@Override
public void purgeCache() {
// no-op
}
}

View File

@ -1,220 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* This MemoryProvider implementation does caching for both host and device memory within predefined limits.
*
* @author raver119@gmail.com
*/
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
//protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation();
//protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache();
protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CacheHolder>> deviceCache =
new ConcurrentHashMap<>();
private static Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);
public CudaFullCachingProvider() {
init();
}
public void init() {
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
deviceCachedAmount = new ArrayList<>();
for (int i = 0; i < numDevices; i++) {
deviceCachedAmount.add(new AtomicLong(0));
}
}
/**
* This method provides PointersPair to memory chunk specified by AllocationShape
*
* PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape.
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
val reqMemory = AllocationUtils.getRequiredMemory(shape);
if (location == AllocationStatus.DEVICE && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) {
val deviceId = AtomicAllocator.getInstance().getDeviceId();
ensureDeviceCacheHolder(deviceId, shape);
val cache = deviceCache.get(deviceId).get(shape);
if (cache != null) {
val pointer = cache.poll();
if (pointer != null) {
cacheDeviceHit.incrementAndGet();
deviceCachedAmount.get(deviceId).addAndGet(-reqMemory);
val pair = new PointersPair();
pair.setDevicePointer(pointer);
point.setAllocationStatus(AllocationStatus.DEVICE);
point.setDeviceId(deviceId);
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory);
return pair;
}
}
cacheDeviceMiss.incrementAndGet();
return super.malloc(shape, point, location);
}
return super.malloc(shape, point, location);
}
/**
* This method frees specific chunk of memory, described by AllocationPoint passed in
*
* PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse.
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
if (point.isConstant())
return;
val shape = point.getShape();
val deviceId = point.getDeviceId();
val address = point.getDevicePointer().address();
val reqMemory = AllocationUtils.getRequiredMemory(shape);
// we don't cache too big objects
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength() || deviceCachedAmount.get(deviceId).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) {
super.free(point);
return;
}
ensureDeviceCacheHolder(deviceId, shape);
val cache = deviceCache.get(deviceId).get(shape);
if (point.getDeviceId() != deviceId)
throw new RuntimeException("deviceId changed!");
// memory chunks < threshold will be cached no matter what
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
cache.put(new CudaPointer(point.getDevicePointer().address()));
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
return;
} else {
cache.put(new CudaPointer(point.getDevicePointer().address()));
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
return;
}
}
super.free(point);
}
/**
* This method checks, if storage contains holder for specified shape
*
* @param deviceId
* @param shape
*/
protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
if (!deviceCache.containsKey(deviceId)) {
try {
synchronized (this) {
if (!deviceCache.containsKey(deviceId)) {
deviceCache.put(deviceId, new ConcurrentHashMap<AllocationShape, CacheHolder>());
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (!deviceCache.get(deviceId).containsKey(shape)) {
try {
singleLock.acquire();
if (!deviceCache.get(deviceId).containsKey(shape)) {
deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount.get(deviceId)));
}
} catch (Exception e) {
} finally {
singleLock.release();
}
}
}
@Override
protected synchronized void purgeCache(int deviceId) {
for (AllocationShape shape : deviceCache.get(deviceId).keySet()) {
Pointer ptr = null;
while ((ptr = deviceCache.get(deviceId).get(shape).poll()) != null) {
freeDevice(ptr, deviceId);
MemoryTracker.getInstance().decrementCachedAmount(deviceId, shape.getNumberOfBytes());
}
}
deviceCachedAmount.get(deviceId).set(0);
}
@Override
public synchronized void purgeCache() {
for (Integer device : deviceCache.keySet()) {
purgeCache(device);
}
super.purgeCache();
}
}

View File

@ -17,34 +17,39 @@
package org.nd4j.linalg.jcublas; package org.nd4j.linalg.jcublas;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.FlatArray;
import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy; import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.JvmShapeInfo; import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.util.PrintVariable;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer; import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.workspace.WorkspaceUtils; import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
/** /**
* *
@ -387,10 +392,6 @@ public class JCublasNDArray extends BaseNDArray {
super(data, order); super(data, order);
} }
public JCublasNDArray(FloatBuffer floatBuffer, char order) {
super(floatBuffer, order);
}
public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) { public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) {
super(buffer, shape, strides); super(buffer, shape, strides);
} }
@ -574,26 +575,16 @@ public class JCublasNDArray extends BaseNDArray {
MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST; MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST;
val prof = PerformanceTracker.getInstance().helperStartTransaction(); val prof = PerformanceTracker.getInstance().helperStartTransaction();
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { if (srcPoint.isActualOnDeviceSide()) {
// d2d copy
route = 1; route = 1;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickDeviceWrite(); dstPoint.tickDeviceWrite();
direction = MemcpyDirection.DEVICE_TO_DEVICE; direction = MemcpyDirection.DEVICE_TO_DEVICE;
} else if (dstPoint.getAllocationStatus() == AllocationStatus.HOST && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { } else {
route = 2;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToHost, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickHostWrite();
direction = MemcpyDirection.DEVICE_TO_HOST;
} else if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.HOST) {
route = 3; route = 3;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickDeviceWrite(); dstPoint.tickDeviceWrite();
direction = MemcpyDirection.HOST_TO_DEVICE; direction = MemcpyDirection.HOST_TO_DEVICE;
} else {
route = 4;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToHost, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickHostWrite();
} }
@ -650,30 +641,16 @@ public class JCublasNDArray extends BaseNDArray {
Nd4j.getMemoryManager().setCurrentWorkspace(target); Nd4j.getMemoryManager().setCurrentWorkspace(target);
// log.info("Leveraging...");
INDArray copy = null; INDArray copy = null;
if (!this.isView()) { if (!this.isView()) {
//if (1 < 0) {
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();
DataBuffer buffer = Nd4j.createBuffer(this.length(), false); val buffer = Nd4j.createBuffer(this.length(), false);
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
/*
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointDst.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memsetAsync 1 failed");
context.syncOldStream();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointSrc.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memsetAsync 2 failed");
context.syncOldStream();
*/
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
val perfD = PerformanceTracker.getInstance().helperStartTransaction(); val perfD = PerformanceTracker.getInstance().helperStartTransaction();
@ -690,12 +667,11 @@ public class JCublasNDArray extends BaseNDArray {
context.syncOldStream(); context.syncOldStream();
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction);
copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
// tag buffer as valid on device side // tag buffer as valid on device side
pointDst.tickHostRead();
pointDst.tickDeviceWrite(); pointDst.tickDeviceWrite();
AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc); AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
@ -728,6 +704,7 @@ public class JCublasNDArray extends BaseNDArray {
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
@ -764,6 +741,38 @@ public class JCublasNDArray extends BaseNDArray {
return copy; return copy;
} }
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) {
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
val numWords = this.length();
val ub = (CudaUtf8Buffer) buffer;
// writing length first
val t = length();
val ptr = (BytePointer) ub.pointer();
// now write all strings as bytes
for (int i = 0; i < ub.length(); i++) {
dos.writeByte(ptr.get(i));
}
val bytes = bos.toByteArray();
return FlatArray.createBufferVector(builder, bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public String getString(long index) {
if (!isS())
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
return ((CudaUtf8Buffer) data).getString(index);
}
/* /*
@Override @Override
public INDArray convertToHalfs() { public INDArray convertToHalfs() {

View File

@ -18,11 +18,9 @@ package org.nd4j.linalg.jcublas;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import lombok.var;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx; import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ops.custom.Flatten; import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.Concat;
@ -34,12 +32,10 @@ import org.nd4j.linalg.jcublas.buffer.*;
import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -51,19 +47,12 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.blas.*; import org.nd4j.linalg.jcublas.blas.*;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.*; import org.nd4j.nativeblas.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.*; import java.util.*;
/** /**
@ -216,7 +205,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override @Override
public INDArray create(Collection<String> strings, long[] shape, char order) { public INDArray create(Collection<String> strings, long[] shape, char order) {
val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8); val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8);
val buffer = new Utf8Buffer(strings); val buffer = new CudaUtf8Buffer(strings);
val list = new ArrayList<String>(strings); val list = new ArrayList<String>(strings);
return Nd4j.createArrayFromShapeBuffer(buffer, pairShape); return Nd4j.createArrayFromShapeBuffer(buffer, pairShape);
} }
@ -360,8 +349,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override @Override
public INDArray concat(int dimension, INDArray... toConcat) { public INDArray concat(int dimension, INDArray... toConcat) {
if (Nd4j.getExecutioner() instanceof GridExecutioner) Nd4j.getExecutioner().push();
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
return Nd4j.exec(new Concat(dimension, toConcat))[0]; return Nd4j.exec(new Concat(dimension, toConcat))[0];
} }
@ -517,9 +505,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
AtomicAllocator allocator = AtomicAllocator.getInstance(); AtomicAllocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(ret, source); CudaContext context = allocator.getFlowController().prepareAction(ret, source);
Pointer x = AtomicAllocator.getInstance().getPointer(source, context); val x = ((BaseCudaDataBuffer) source.data()).getOpaqueDataBuffer();
val z = ((BaseCudaDataBuffer) ret.data()).getOpaqueDataBuffer();
Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context); Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context);
Pointer z = AtomicAllocator.getInstance().getPointer(ret, context);
Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context); Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
@ -545,14 +533,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
nativeOps.pullRows(extras, nativeOps.pullRows(extras,
null, x, (LongPointer) source.shapeInfoDataBuffer().addressPointer(), (LongPointer) xShape,
(LongPointer) source.shapeInfoDataBuffer().addressPointer(), z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) zShape,
x,
(LongPointer) xShape,
null,
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
z,
(LongPointer) zShape,
indexes.length, indexes.length,
(LongPointer) pIndex, (LongPointer) pIndex,
(LongPointer) tadShapeInfo, (LongPointer) tadShapeInfo,
@ -601,7 +583,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
AllocationPoint point = allocator.getAllocationPoint(arrays[i]); AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
xPointers[i] = point.getPointers().getDevicePointer().address(); xPointers[i] = point.getDevicePointer().address();
point.tickDeviceWrite(); point.tickDeviceWrite();
} }
@ -710,7 +692,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
AllocationPoint point = allocator.getAllocationPoint(arrays[i]); AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
xPointers[i] = point.getPointers().getDevicePointer().address(); xPointers[i] = point.getDevicePointer().address();
point.tickDeviceWrite(); point.tickDeviceWrite();
} }
@ -1324,11 +1306,11 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
PointerPointer extraz = new PointerPointer(null, // not used PointerPointer extraz = new PointerPointer(null, // not used
context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());
val x = ((BaseCudaDataBuffer) tensor.data()).getOpaqueDataBuffer();
nativeOps.tear(extraz, nativeOps.tear(extraz,
null, x, (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
(LongPointer) tensor.shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(tensor, context),
(LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)), new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)),
(LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context),
(LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),

View File

@ -46,6 +46,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length); super(pointer, specialPointer, indexer, length);
} }
public CudaBfloat16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/** /**
* Base constructor * Base constructor
* *
@ -128,18 +132,6 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset); super(data, copy, offset);
} }
public CudaBfloat16DataBuffer(byte[] data, long length) {
super(data, length, DataType.BFLOAT16);
}
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.BFLOAT16);
}
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.BFLOAT16);
}
@Override @Override
public void assign(long[] indices, double[] data, boolean contiguous, long inc) { public void assign(long[] indices, double[] data, boolean contiguous, long inc) {

View File

@ -50,6 +50,10 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length); super(pointer, specialPointer, indexer, length);
} }
public CudaBoolDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/** /**
* Base constructor * Base constructor
* *
@ -132,18 +136,6 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset); super(data, copy, offset);
} }
public CudaBoolDataBuffer(byte[] data, long length) {
super(data, length, DataType.HALF);
}
public CudaBoolDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.HALF);
}
public CudaBoolDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.HALF);
}
@Override @Override
protected DataBuffer create(long length) { protected DataBuffer create(long length) {
return new CudaBoolDataBuffer(length); return new CudaBoolDataBuffer(length);

View File

@ -49,6 +49,10 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length); super(pointer, specialPointer, indexer, length);
} }
public CudaByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/** /**
* Base constructor * Base constructor
* *
@ -131,18 +135,6 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset); super(data, copy, offset);
} }
public CudaByteDataBuffer(byte[] data, long length) {
super(data, length, DataType.HALF);
}
public CudaByteDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.HALF);
}
public CudaByteDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.HALF);
}
@Override @Override
protected DataBuffer create(long length) { protected DataBuffer create(long length) {
return new CudaByteDataBuffer(length); return new CudaByteDataBuffer(length);

View File

@ -49,6 +49,10 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length); super(pointer, specialPointer, indexer, length);
} }
public CudaDoubleDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/** /**
* Base constructor * Base constructor
* *
@ -138,18 +142,6 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset); super(data, copy, offset);
} }
public CudaDoubleDataBuffer(byte[] data, long length) {
super(data, length, DataType.DOUBLE);
}
public CudaDoubleDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.DOUBLE);
}
public CudaDoubleDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.DOUBLE);
}
@Override @Override
protected DataBuffer create(long length) { protected DataBuffer create(long length) {
return new CudaDoubleDataBuffer(length); return new CudaDoubleDataBuffer(length);
@ -210,14 +202,7 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
this.length = n; this.length = n;
this.elementSize = 8; this.elementSize = 8;
//wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize()); this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.DOUBLE), false);
//wrappedBuffer.order(ByteOrder.nativeOrder());
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
new AllocationShape(length, elementSize, DataType.DOUBLE), false);
this.trackingPoint = allocationPoint.getObjectId();
//this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer();
//this.wrappedBuffer.order(ByteOrder.nativeOrder());
setData(arr); setData(arr);
} }

View File

@ -50,6 +50,10 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length); super(pointer, specialPointer, indexer, length);
} }
public CudaFloatDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/** /**
* Base constructor * Base constructor
* *
@ -133,19 +137,6 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset); super(data, copy, offset);
} }
public CudaFloatDataBuffer(byte[] data, long length) {
super(data, length, DataType.FLOAT);
}
public CudaFloatDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.FLOAT);
}
public CudaFloatDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.FLOAT);
}
@Override @Override
protected DataBuffer create(long length) { protected DataBuffer create(long length) {
return new CudaFloatDataBuffer(length); return new CudaFloatDataBuffer(length);

Some files were not shown because too many files have changed in this diff Show More