String changes (#3)

* initial commit

* additional data types & tensor type

Signed-off-by: raver119 <raver119@gmail.com>

* next step

Signed-off-by: raver119 <raver119@gmail.com>

* missing include

* sparse_to_dense

Signed-off-by: raver119 <raver119@gmail.com>

* few more test files

Signed-off-by: raver119 <raver119@gmail.com>

* draft

Signed-off-by: raver119 <raver119@gmail.com>

* numeric sparse_to_dense

Signed-off-by: raver119 <raver119@gmail.com>

* comment

Signed-off-by: raver119 <raver119@gmail.com>

* string sparse_to_dense version

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA DataBuffer expand

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks for CUDA build

Signed-off-by: raver119 <raver119@gmail.com>

* shape fn for string_split

Signed-off-by: raver119 <raver119@gmail.com>

* one more comment

Signed-off-by: raver119 <raver119@gmail.com>

* string_split indices

Signed-off-by: raver119 <raver119@gmail.com>

* next step

Signed-off-by: raver119 <raver119@gmail.com>

* test passes

Signed-off-by: raver119 <raver119@gmail.com>

* few rearrangements for databuffer implementations

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer: move inline methods to common implementations

Signed-off-by: raver119 <raver119@gmail.com>

* add native DataBuffer to Nd4j presets

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer creation

Signed-off-by: raver119 <raver119@gmail.com>

* use DataBuffer for allocation

Signed-off-by: raver119 <raver119@gmail.com>

* cpu databuffer as deallocatable

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer setters for buffers

Signed-off-by: raver119 <raver119@gmail.com>

* couple of wrappers

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffers being passed around

Signed-off-by: raver119 <raver119@gmail.com>

* Bunch of ByteBuffer-related signatures gone

Signed-off-by: raver119 <raver119@gmail.com>

* - few more Nd4j signatures removed
- minor fix for bfloat16

Signed-off-by: raver119 <raver119@gmail.com>

* nullptr pointer is still a pointer, but 0 as address :)

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* empty string array init

Signed-off-by: raver119 <raver119@gmail.com>

* one more test in cpp

Signed-off-by: raver119 <raver119@gmail.com>

* memcpy instead of databuffer swap

Signed-off-by: raver119 <raver119@gmail.com>

* special InteropDataBuffer for front-end languages

Signed-off-by: raver119 <raver119@gmail.com>

* few tweaks for java

Signed-off-by: raver119 <raver119@gmail.com>

* pointer/indexer actualization

Signed-off-by: raver119 <raver119@gmail.com>

* CustomOp returns lists for inputArguments and outputArguments instead of arrays

Signed-off-by: raver119 <raver119@gmail.com>

* redundant call

Signed-off-by: raver119 <raver119@gmail.com>

* print_variable op

Signed-off-by: raver119 <raver119@gmail.com>

* - view handling (but wrong one)
- print_variable java wrapper

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* - empty arrays handling

Signed-off-by: raver119 <raver119@gmail.com>

* - deserialization works now

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* one more fix

Signed-off-by: raver119 <raver119@gmail.com>

* initial cuda commit

Signed-off-by: raver119 <raver119@gmail.com>

* print_variable message validation

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA views

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA special buffer size

Signed-off-by: raver119 <raver119@gmail.com>

* minor update to match master changes

Signed-off-by: raver119 <raver119@gmail.com>

* - consider arrays always actual on device for CUDA
- additional PrintVariable constructor
- CudaUtf8Buffer now allocates host buffer by default

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* - print_variable now allows print from device

Signed-off-by: raver119 <raver119@gmail.com>

* InteropDataBuffer data type fix

Signed-off-by: raver119 <raver119@gmail.com>

* ...

Signed-off-by: raver119 <raver119@gmail.com>

* disable some debug messages

Signed-off-by: raver119 <raver119@gmail.com>

* master pulled in

Signed-off-by: raver119 <raver119@gmail.com>

* couple of new methods for DataBuffer interop

Signed-off-by: raver119 <raver119@gmail.com>

* java side

Signed-off-by: raver119 <raver119@gmail.com>

* offsetted constructor

Signed-off-by: raver119 <raver119@gmail.com>

* new CUDA deallocator

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart 2

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA backend torn apart 3

Signed-off-by: raver119 <raver119@gmail.com>

* - few new tests
- few new methods for DataBuffer management

Signed-off-by: raver119 <raver119@gmail.com>

* few more tests + few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* two failing tests

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* two failing tests pass

Signed-off-by: raver119 <raver119@gmail.com>

* now we pass DataBuffer to legacy ops too

Signed-off-by: raver119 <raver119@gmail.com>

* Native DataBuffer for legacy ops, Java side

Signed-off-by: raver119 <raver119@gmail.com>

* CPU java side update

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA java side update

Signed-off-by: raver119 <raver119@gmail.com>

* no more prepare/register action on java side

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray::prepare/register*Use methods now accept vectors

Signed-off-by: raver119 <raver119@gmail.com>

* InteropDataBuffer now has a few more convenience methods

Signed-off-by: raver119 <raver119@gmail.com>

* java bindings update

Signed-off-by: raver119 <raver119@gmail.com>

* tick device in NativeOps

Signed-off-by: raver119 <raver119@gmail.com>

* Corrected usage of OpaqueBuffer for tests.

* Corrected usage of OpaqueBuffer for java tests.

* NativeOpsTests fixes.

* print_variable now returns scalar

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* compat_string_split fix for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* - CUDA execScalar fix
- CUDA lazyAllocateHostPointer now checks java indexer/pointer instead of native pointer

Signed-off-by: raver119 <raver119@gmail.com>

* legacy ops DataBuffer migration prototype

Signed-off-by: raver119 <raver119@gmail.com>

* ignore device shapeinfo coming from java

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor transformAny fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweak for lazy host allocation

Signed-off-by: raver119 <raver119@gmail.com>

* - DataBuffer::memcpy method
- bitcast now uses memcpy

Signed-off-by: raver119 <raver119@gmail.com>

* - IndexReduce CUDA dimension buffer fix

Signed-off-by: raver119 <raver119@gmail.com>

* views for CPU and CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less spam

Signed-off-by: raver119 <raver119@gmail.com>

* optional memory init

Signed-off-by: raver119 <raver119@gmail.com>

* async memset

Signed-off-by: raver119 <raver119@gmail.com>

* - SummaryStats CUDA fix
- DataBuffer.sameUnderlyingData() impl
- execBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* - reduce3All fix
- switch to CUDA 10 temporarily

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA version

Signed-off-by: raver119 <raver119@gmail.com>

* proper memory deallocator registration

Signed-off-by: raver119 <raver119@gmail.com>

* HOST_ONLY workspace allocation

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* few conflicts resolved

Signed-off-by: raver119 <raver119@gmail.com>

* few minor fixes

Signed-off-by: raver119 <raver119@gmail.com>

* one more minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray permute should operate on JVM primitives

Signed-off-by: raver119 <raver119@gmail.com>

* - create InteropDataBuffer for shapes as well
- update pointers after view creation in Java

Signed-off-by: raver119 <raver119@gmail.com>

* - addressPointer temporary moved to C++

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA: don't account offset twice

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA: DataBuffer pointer constructor updated

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA NDArray.unsafeDuplication() simplified

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA minor workspace-related fixes

Signed-off-by: raver119 <raver119@gmail.com>

* CPU DataBuffer.reallocate()

Signed-off-by: raver119 <raver119@gmail.com>

* print_affinity op

Signed-off-by: raver119 <raver119@gmail.com>

* print_affinity java side

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA more tweaks for data locality

Signed-off-by: raver119 <raver119@gmail.com>

* - compat_string_split tweak
- CudaUtf8Buffer update

Signed-off-by: raver119 <raver119@gmail.com>

* INDArray.close() mechanic restored

Signed-off-by: raver119 <raver119@gmail.com>

* one more test fixed

Signed-off-by: raver119 <raver119@gmail.com>

* - CUDA DataBuffer.reallocate() updated
- cudaMemcpy (synchronous) restored

Signed-off-by: raver119 <raver119@gmail.com>

* one last fix

Signed-off-by: raver119 <raver119@gmail.com>

* bad import removed

Signed-off-by: raver119 <raver119@gmail.com>

* another small fix

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* fix bad databuffer size

Signed-off-by: raver119 <raver119@gmail.com>

* release primaryBuffer on replace

Signed-off-by: raver119 <raver119@gmail.com>

* higher timeout

Signed-off-by: raver119 <raver119@gmail.com>

* disable timeouts

Signed-off-by: raver119 <raver119@gmail.com>

* dbCreateView now validates offset and length of a view

Signed-off-by: raver119 <raver119@gmail.com>

* additional validation for dbExpand

Signed-off-by: raver119 <raver119@gmail.com>

* restore timeout back again

Signed-off-by: raver119 <raver119@gmail.com>

* smaller distribution for rng test to prevent timeouts

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA DataBuffer::memcpy now copies to device all the time

Signed-off-by: raver119 <raver119@gmail.com>

* OpaqueDataBuffer now contains all required methods for interop

Signed-off-by: raver119 <raver119@gmail.com>

* some javadoc

Signed-off-by: raver119 <raver119@gmail.com>

* GC on failed allocations

Signed-off-by: raver119 <raver119@gmail.com>

* minor memcpy tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more bitcast test

Signed-off-by: raver119 <raver119@gmail.com>

* - NDArray::deviceId() propagation
- special multi-threaded test for data locality checks

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer additional syncStream

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer additional syncStream

Signed-off-by: raver119 <raver119@gmail.com>

* one ignored test

Signed-off-by: raver119 <raver119@gmail.com>

* skip host alloc for empty arrays

Signed-off-by: raver119 <raver119@gmail.com>

* ByteBuffer support is back

Signed-off-by: raver119 <raver119@gmail.com>

* DataBuffer::memcpy minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* few minor prelu/bp tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* nullify-related fixes

Signed-off-by: raver119 <raver119@gmail.com>

* PReLU fixes (#157)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Build fixed

* Fix tests

* one more ByteBuffer signature restored

Signed-off-by: raver119 <raver119@gmail.com>

* nd4j-jdbc-hsql profiles fix

Signed-off-by: raver119 <raver119@gmail.com>

* nd4j-jdbc-hsql profiles fix

Signed-off-by: raver119 <raver119@gmail.com>

* PReLU weight init fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small PReLU fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* - INDArray.migrate() reactivated
- DataBuffer::setDeviceId(...) added
- InteropDataBuffer Z syncToDevice added for views

Signed-off-by: raver119 <raver119@gmail.com>

* missed file

Signed-off-by: raver119 <raver119@gmail.com>

* Small tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* cuda 10.2

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: shugeo <sgazeos@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
raver119 2020-01-04 13:27:50 +03:00 committed by GitHub
parent 451d9d57fd
commit 29e8e09db6
169 changed files with 8463 additions and 7839 deletions

View File

@ -121,6 +121,7 @@ public class PReLULayer extends BaseLayer {
public static class Builder extends FeedForwardLayer.Builder<PReLULayer.Builder> {
public Builder(){
//Default to 0s, and don't inherit global default
this.weightInitFn = new WeightInitConstant(0);
}

View File

@ -20,7 +20,7 @@ import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -63,7 +63,7 @@ public class NegativeHolder implements Serializable {
protected void makeTable(int tableSize, double power) {
int vocabSize = vocab.numWords();
table = Nd4j.create(new FloatBuffer(tableSize));
table = Nd4j.create(DataType.FLOAT, tableSize);
double trainWordsPow = 0.0;
for (String word : vocab.words()) {
trainWordsPow += Math.pow(vocab.wordFrequency(word), power);

View File

@ -42,6 +42,8 @@
#include <helpers/ConstantShapeHelper.h>
#include <array/DataBuffer.h>
#include <execution/AffinityManager.h>
#include <memory>
#include <array/InteropDataBuffer.h>
namespace nd4j {
@ -301,14 +303,11 @@ namespace nd4j {
* @param writeList
* @param readList
*/
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
/**
* This method returns buffer pointer offset by given number of elements, wrt own data type

View File

@ -223,6 +223,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& desc
setShapeInfo(descriptor);
_buffer = buffer;
_isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
}
////////////////////////////////////////////////////////////////////////
@ -288,6 +290,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std
setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape));
_buffer = buffer;
_isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
}
////////////////////////////////////////////////////////////////////////

View File

@ -68,6 +68,7 @@ bool verbose = false;
#include <array/ConstantDescriptor.h>
#include <helpers/ConstantShapeHelper.h>
#include <array/ConstantDataBuffer.h>
#include <array/InteropDataBuffer.h>
#include <helpers/ConstantHelper.h>
#include <array/TadPack.h>
#include <graph/VariablesSet.h>
@ -76,6 +77,8 @@ bool verbose = false;
#include <graph/ResultWrapper.h>
#include <DebugInfo.h>
typedef nd4j::InteropDataBuffer OpaqueDataBuffer;
extern "C" {
/**
@ -118,11 +121,9 @@ ND4J_EXPORT void setTADThreshold(int num);
*/
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
/**
*
@ -137,13 +138,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
/**
*
@ -160,28 +158,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
ND4J_EXPORT void execBroadcast(
Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execBroadcastBool(
Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
/**
*
@ -198,23 +188,17 @@ ND4J_EXPORT void execBroadcastBool(
ND4J_EXPORT void execPairwiseTransform(
Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execPairwiseTransformBool(
Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
/**
@ -228,36 +212,28 @@ ND4J_EXPORT void execPairwiseTransformBool(
*/
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
/**
*
@ -270,46 +246,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
/**
*
@ -324,13 +288,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
/**
*
@ -343,13 +304,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
/**
*
* @param opNum
@ -365,30 +323,22 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParamsVals,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
@ -405,22 +355,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
void *extraParams);
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
void *extraParams);
/**
@ -432,11 +376,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
bool biasCorrected);
/**
*
@ -449,11 +391,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
bool biasCorrected);
/**
*
@ -468,13 +408,10 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
bool biasCorrected,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
@ -490,42 +427,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
void *extraParams);
/**
@ -543,29 +470,21 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape,
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
@ -904,10 +823,8 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
* @param zTadOffsets
*/
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *z, Nd4jLong *zShapeInfo,
void *dz, Nd4jLong *dzShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo,
Nd4jLong n,
Nd4jLong *indexes,
Nd4jLong *tadShapeInfo,
@ -1086,8 +1003,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *extraArguments);
/**
@ -1106,12 +1022,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *extraArguments);
/**
@ -1128,10 +1041,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
int opNum,
Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
void *extraArguments);
@ -1174,52 +1085,6 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers,
*/
ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
/**
* Grid operations
*/
/**
*
* @param extras
* @param opTypeA
* @param opNumA
* @param opTypeB
* @param opNumB
* @param N
* @param dx
* @param xShapeInfo
* @param dy
* @param yShapeInfo
* @param dz
* @param zShapeInfo
* @param extraA
* @param extraB
* @param scalarA
* @param scalarB
*/
/*
ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras,
const int opTypeA,
const int opNumA,
const int opTypeB,
const int opNumB,
Nd4jLong N,
void *hX, Nd4jLong *hXShapeBuffer,
void *dX, Nd4jLong *dXShapeBuffer,
void *hY, Nd4jLong *hYShapeBuffer,
void *dY, Nd4jLong *dYShapeBuffer,
void *hZ, Nd4jLong *hZShapeBuffer,
void *dZ, Nd4jLong *dZShapeBuffer,
void *extraA,
void *extraB,
double scalarA,
double scalarB);
*/
}
/**
@ -1561,8 +1426,7 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address);
* @return
*/
ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
Nd4jPointer *targets, Nd4jLong *zShapeInfo,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets);
@ -1739,6 +1603,8 @@ ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace)
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
@ -1766,6 +1632,28 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth);
ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset);
ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes);
ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes);
ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId);
ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
ND4J_EXPORT int binaryLevel();
ND4J_EXPORT int optimalLevel();
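
For orientation, here is a minimal, hypothetical usage sketch of the OpaqueDataBuffer interop API declared above, written against the signatures in this header. The include paths and the FLOAT32 element type are assumptions for illustration only, and error handling is omitted.

#include <NativeOps.h>        // assumption: the header carrying the ND4J_EXPORT declarations above
#include <array/DataType.h>   // nd4j::DataType enum

int main() {
    // allocate a descriptor for 512 float elements; allocateBoth = true also
    // reserves the special (device) buffer where the backend supports it
    OpaqueDataBuffer *db = allocateDataBuffer(512, (int) nd4j::DataType::FLOAT32, true);

    dbAllocatePrimaryBuffer(db);                                   // make sure the host-side buffer exists
    auto host = reinterpret_cast<float *>(dbPrimaryBuffer(db));
    for (int i = 0; i < 512; i++)
        host[i] = 1.0f;

    dbTickHostWrite(db);                                           // mark the host copy as most recent
    dbSyncToSpecial(db);                                           // on CUDA builds this pushes the data to the device

    deleteDataBuffer(db);                                          // releases the wrapper and, if owned, the underlying buffers
    return 0;
}
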

View File

@ -184,16 +184,16 @@ void NDArray::synchronize(const char* msg) const {
// no-op
}
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
// no-op
}
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
// no-op
}
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
// no-op
}
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
// no-op
}
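
As a hedged illustration of the signature change above (initializer_list to std::vector), the sketch below shows how the vector-based prepare/register pair might wrap a device-side operation. The helper name, the include path, and the x/y/z arguments are assumptions, not part of the change itself.

#include <NDArray.h>   // assumption about the include path for nd4j::NDArray
#include <vector>

// Hypothetical helper: wrap a device-side op on inputs x, y and output z
// with the new std::vector-based prepare/register protocol.
void runOnDevice(const nd4j::NDArray &x, const nd4j::NDArray &y, nd4j::NDArray &z) {
    std::vector<const nd4j::NDArray *> writeList = {&z};
    std::vector<const nd4j::NDArray *> readList  = {&x, &y};

    // sync the read list to device if needed (no-op on the CPU backend, as shown above)
    nd4j::NDArray::prepareSpecialUse(writeList, readList);

    // ... launch the kernel / call the legacy op here ...

    // tick the read/write counters so data-locality tracking stays consistent
    nd4j::NDArray::registerSpecialUse(writeList, readList);
}
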

File diff suppressed because it is too large

View File

@ -236,7 +236,7 @@ void NDArray::synchronize(const char* msg) const {
}
////////////////////////////////////////////////////////////////////////
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList)
if(a != nullptr)
@ -252,7 +252,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
}
////////////////////////////////////////////////////////////////////////
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
for (const auto& p : readList)
if(p != nullptr)
@ -264,7 +264,7 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
}
////////////////////////////////////////////////////////////////////////
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList)
if(a != nullptr)
@ -280,7 +280,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
}
////////////////////////////////////////////////////////////////////////
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
for (const auto& p : readList)
if(p != nullptr)

File diff suppressed because it is too large

View File

@ -34,10 +34,12 @@
#define ARRAY_SPARSE 2
#define ARRAY_COMPRESSED 4
#define ARRAY_EMPTY 8
#define ARRAY_RAGGED 16
#define ARRAY_CSR 16
#define ARRAY_CSC 32
#define ARRAY_COO 64
#define ARRAY_CSR 32
#define ARRAY_CSC 64
#define ARRAY_COO 128
// complex values
#define ARRAY_COMPLEX 512
@ -72,8 +74,10 @@
// boolean values
#define ARRAY_BOOL 524288
// utf-8 values
#define ARRAY_STRING 1048576
// UTF values
#define ARRAY_UTF8 1048576
#define ARRAY_UTF16 4194304
#define ARRAY_UTF32 16777216
// flag for extras
#define ARRAY_EXTRAS 2097152
@ -173,8 +177,12 @@ namespace nd4j {
return nd4j::DataType ::UINT32;
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
return nd4j::DataType ::UINT64;
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING))
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
return nd4j::DataType ::UTF8;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
return nd4j::DataType ::UTF16;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
return nd4j::DataType ::UTF32;
else {
//shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
@ -190,8 +198,12 @@ namespace nd4j {
return nd4j::DataType::INT32;
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
return nd4j::DataType::INT64;
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING))
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
return nd4j::DataType::UTF8;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
return nd4j::DataType::UTF16;
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
return nd4j::DataType::UTF32;
else {
//shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
@ -224,6 +236,8 @@ namespace nd4j {
return ArrayType::COMPRESSED;
else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
return ArrayType::EMPTY;
else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED))
return ArrayType::RAGGED;
else // by default we return DENSE type here
return ArrayType::DENSE;
}
@ -333,7 +347,13 @@ namespace nd4j {
setPropertyBit(shapeInfo, ARRAY_LONG);
break;
case nd4j::DataType::UTF8:
setPropertyBit(shapeInfo, ARRAY_STRING);
setPropertyBit(shapeInfo, ARRAY_UTF8);
break;
case nd4j::DataType::UTF16:
setPropertyBit(shapeInfo, ARRAY_UTF16);
break;
case nd4j::DataType::UTF32:
setPropertyBit(shapeInfo, ARRAY_UTF32);
break;
default:
#ifndef __CUDA_ARCH__

View File

@ -27,6 +27,7 @@ namespace nd4j {
SPARSE = 2,
COMPRESSED = 3,
EMPTY = 4,
RAGGED = 5,
};
}

View File

@ -36,13 +36,14 @@ class ND4J_EXPORT DataBuffer {
private:
void* _primaryBuffer;
void* _specialBuffer;
size_t _lenInBytes;
void* _primaryBuffer = nullptr;
void* _specialBuffer = nullptr;
size_t _lenInBytes = 0;
DataType _dataType;
memory::Workspace* _workspace;
memory::Workspace* _workspace = nullptr;
bool _isOwnerPrimary;
bool _isOwnerSpecial;
std::atomic<int> _deviceId;
#ifdef __CUDABLAS__
mutable std::atomic<Nd4jLong> _counter;
@ -55,9 +56,9 @@ class ND4J_EXPORT DataBuffer {
void setCountersToZero();
void copyCounters(const DataBuffer& other);
void deleteSpecial();
FORCEINLINE void deletePrimary();
FORCEINLINE void deleteBuffers();
FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
void deletePrimary();
void deleteBuffers();
void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
void allocateBuffers(const bool allocBoth = false);
void setSpecial(void* special, const bool isOwnerSpecial);
void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0);
@ -65,37 +66,38 @@ class ND4J_EXPORT DataBuffer {
public:
FORCEINLINE DataBuffer(void* primary, void* special,
DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false, const bool isOwnerSpecial = false,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(void* primary,
DataBuffer(void* primary,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
const DataType dataType, const size_t lenInBytes,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
FORCEINLINE DataBuffer(const DataBuffer& other);
FORCEINLINE DataBuffer(DataBuffer&& other);
FORCEINLINE explicit DataBuffer();
FORCEINLINE ~DataBuffer();
DataBuffer(const DataBuffer& other);
DataBuffer(DataBuffer&& other);
explicit DataBuffer();
~DataBuffer();
FORCEINLINE DataBuffer& operator=(const DataBuffer& other);
FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept;
DataBuffer& operator=(const DataBuffer& other);
DataBuffer& operator=(DataBuffer&& other) noexcept;
FORCEINLINE DataType getDataType();
FORCEINLINE size_t getLenInBytes() const;
DataType getDataType();
void setDataType(DataType dataType);
size_t getLenInBytes() const;
FORCEINLINE void* primary();
FORCEINLINE void* special();
void* primary();
void* special();
FORCEINLINE void allocatePrimary();
void allocatePrimary();
void allocateSpecial();
void writePrimary() const;
@ -105,6 +107,10 @@ class ND4J_EXPORT DataBuffer {
bool isPrimaryActual() const;
bool isSpecialActual() const;
void expand(const uint64_t size);
int deviceId() const;
void setDeviceId(int deviceId);
void migrate();
template <typename T> FORCEINLINE T* primaryAsT();
@ -118,256 +124,28 @@ class ND4J_EXPORT DataBuffer {
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
void setPrimaryBuffer(void *buffer, size_t length);
void setSpecialBuffer(void *buffer, size_t length);
/**
* This method deletes buffers, if we're owners
*/
void close();
};
///// IMLEMENTATION OF INLINE METHODS /////
////////////////////////////////////////////////////////////////////////
// default constructor
DataBuffer::DataBuffer() {
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = 0;
_dataType = INT8;
_workspace = nullptr;
_isOwnerPrimary = false;
_isOwnerSpecial = false;
setCountersToZero();
}
////////////////////////////////////////////////////////////////////////
// copy constructor
DataBuffer::DataBuffer(const DataBuffer &other) {
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
allocateBuffers();
copyBufferFrom(other);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary, const bool isOwnerSpecial,
memory::Workspace* workspace) {
if (primary == nullptr && special == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
_primaryBuffer = primary;
_specialBuffer = special;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
setCountersToZero();
if(primary != nullptr)
readPrimary();
if(special != nullptr)
readSpecial();
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
syncToSpecial(true);
}
////////////////////////////////////////////////////////////////////////
// copies data from hostBuffer to own memory buffer
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
if (hostBuffer == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
if (lenInBytes == 0)
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
setCountersToZero();
allocateBuffers();
copyBufferFromHost(hostBuffer, lenInBytes);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
_dataType = dataType;
_workspace = workspace;
_lenInBytes = lenInBytes;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
if(lenInBytes != 0) {
allocateBuffers(allocBoth);
writeSpecial();
}
}
////////////////////////////////////////////////////////////////////////
// move constructor
DataBuffer::DataBuffer(DataBuffer&& other) {
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
if (this == &other)
return *this;
deleteBuffers();
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
if (this == &other)
return *this;
deleteBuffers();
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::primary() {
return _primaryBuffer;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::special() {
return _specialBuffer;
}
////////////////////////////////////////////////////////////////////////
DataType DataBuffer::getDataType() {
return _dataType;
}
////////////////////////////////////////////////////////////////////////
size_t DataBuffer::getLenInBytes() const {
return _lenInBytes;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T* DataBuffer::primaryAsT() {
template <typename T>
T* DataBuffer::primaryAsT() {
return reinterpret_cast<T*>(_primaryBuffer);
}
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T* DataBuffer::specialAsT() {
template <typename T>
T* DataBuffer::specialAsT() {
return reinterpret_cast<T*>(_specialBuffer);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocatePrimary() {
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
_isOwnerPrimary = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deletePrimary() {
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deleteBuffers() {
deletePrimary();
deleteSpecial();
_lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
DataBuffer::~DataBuffer() {
deleteBuffers();
}
}

View File

@ -42,6 +42,8 @@ namespace nd4j {
QINT16 = 16,
BFLOAT16 = 17,
UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
ANY = 100,
AUTO = 200,
};

View File

@ -0,0 +1,71 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <dll.h>
#include <array/DataBuffer.h>
#include <array/DataType.h>
#include <memory>
#ifndef LIBND4J_INTEROPDATABUFFER_H
#define LIBND4J_INTEROPDATABUFFER_H
namespace nd4j {
/**
* This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages
*/
class ND4J_EXPORT InteropDataBuffer {
private:
std::shared_ptr<DataBuffer> _dataBuffer;
uint64_t _offset = 0;
public:
InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset);
InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer);
InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth);
~InteropDataBuffer() = default;
#ifndef __JAVACPP_HACK__
std::shared_ptr<DataBuffer> getDataBuffer() const;
std::shared_ptr<DataBuffer> dataBuffer();
#endif
void* primary() const;
void* special() const;
uint64_t offset() const ;
void setOffset(uint64_t offset);
void setPrimary(void* ptr, size_t length);
void setSpecial(void* ptr, size_t length);
void expand(size_t newlength);
int deviceId() const;
void setDeviceId(int deviceId);
static void registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
static void prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
static void registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
static void preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
};
}
#endif //LIBND4J_INTEROPDATABUFFER_H
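
For illustration only (not code from this changeset), a minimal usage sketch of the wrapper above, assuming the constructors and accessors declared in this header; lengths and offsets at this level are byte-based, which is what the range check in the implementation compares against.

#include <array/DataBuffer.h>
#include <array/InteropDataBuffer.h>
#include <memory>

// hypothetical usage sketch of InteropDataBuffer views
void interopViewSketch() {
    // backing buffer: 16 floats, host + device allocations requested
    auto base = std::make_shared<nd4j::DataBuffer>(16 * sizeof(float), nd4j::DataType::FLOAT32, nullptr, true);
    nd4j::InteropDataBuffer full(base);

    // view over the second half: length and offset are given in bytes
    nd4j::InteropDataBuffer view(full, 8 * sizeof(float), 8 * sizeof(float));

    // primary()/special() apply the byte offset to the shared DataBuffer
    auto hostHalf = reinterpret_cast<float*>(view.primary());
    (void) hostHalf;
}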

View File

@ -23,6 +23,24 @@
#include <DataTypeUtils.h>
namespace nd4j {
void DataBuffer::expand(const uint64_t size) {
if (size > _lenInBytes) {
// allocate new buffer
int8_t *newBuffer = nullptr;
ALLOCATE(newBuffer, _workspace, size, int8_t);
// copy data from existing buffer
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
if (_isOwnerPrimary) {
RELEASE(reinterpret_cast<int8_t *>(_primaryBuffer), _workspace);
}
_primaryBuffer = newBuffer;
_lenInBytes = size;
_isOwnerPrimary = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setCountersToZero() {
@ -99,14 +117,17 @@ void DataBuffer::allocateSpecial() {
void DataBuffer::migrate() {
}
///////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes);
/////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes > dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes);
dst.readPrimary();
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { }
void DataBuffer::writeSpecial() const { }
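
For illustration only (not code from this changeset), a sketch of the expand() contract implemented above: growth-only, with the existing bytes copied into the new allocation and ownership taken over. It assumes the DataBuffer constructors shown elsewhere in this change and a host-side (CPU) build.

#include <array/DataBuffer.h>
#include <cassert>

// hypothetical sketch of DataBuffer::expand semantics
void expandSketch() {
    nd4j::DataBuffer buffer(4 * sizeof(int), nd4j::DataType::INT32, nullptr, false);
    auto data = reinterpret_cast<int*>(buffer.primary());
    for (int e = 0; e < 4; e++)
        data[e] = e;

    // growing preserves the first 4 values; requests smaller than the current length are ignored
    buffer.expand(16 * sizeof(int));
    data = reinterpret_cast<int*>(buffer.primary());
    assert(data[3] == 3 && buffer.getLenInBytes() == 16 * sizeof(int));
}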

View File

@ -25,6 +25,40 @@
#include <exceptions/cuda_exception.h>
namespace nd4j {
void DataBuffer::expand(const uint64_t size) {
if (size > _lenInBytes) {
// allocate new buffer
int8_t *newBuffer = nullptr;
int8_t *newSpecialBuffer = nullptr;
ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t);
// copy data from existing buffer
if (_primaryBuffer != nullptr) {
// there's non-zero chance that primary buffer doesn't exist yet
ALLOCATE(newBuffer, _workspace, size, int8_t);
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
if (_isOwnerPrimary) {
auto ipb = reinterpret_cast<int8_t *>(_primaryBuffer);
RELEASE(ipb, _workspace);
}
_primaryBuffer = newBuffer;
_isOwnerPrimary = true;
}
cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice);
if (_isOwnerSpecial) {
auto isb = reinterpret_cast<int8_t *>(_specialBuffer);
RELEASE_SPECIAL(isb, _workspace);
}
_specialBuffer = newSpecialBuffer;
_lenInBytes = size;
_isOwnerSpecial = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocateSpecial() {
@ -37,8 +71,9 @@ void DataBuffer::allocateSpecial() {
////////////////////////////////////////////////////////////////////////
void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) {
if(isPrimaryActual() && !forceSync)
if(isPrimaryActual() && !forceSync) {
return;
}
allocatePrimary();
@ -46,7 +81,9 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
if (res != 0)
throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res);
cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost);
res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost);
if (res != 0)
throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res);
readPrimary();
}
@ -54,13 +91,19 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
////////////////////////////////////////////////////////////////////////
void DataBuffer::syncToSpecial(const bool forceSync) {
if(isSpecialActual() && !forceSync)
// in this case there's nothing to do here
if (_primaryBuffer == nullptr)
return;
if(isSpecialActual() && !forceSync) {
return;
}
allocateSpecial();
cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice);
auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice);
if (res != 0)
throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res);
readSpecial();
}
@ -97,19 +140,6 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
_readPrimary.store(other._writeSpecial);
_readSpecial.store(other._writePrimary);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
if (src.isSpecialActual()) {
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
} else if (src.isPrimaryActual()) {
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
}
dst.writeSpecial();
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
@ -176,8 +206,11 @@ void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate s
////////////////////////////////////////////////////////////////////////
void DataBuffer::setToZeroBuffers(const bool both) {
cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream());
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
if (res != 0)
throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res);
cudaMemset(special(), 0, getLenInBytes());
writeSpecial();
if(both) {
@ -186,12 +219,37 @@ void DataBuffer::setToZeroBuffers(const bool both) {
}
}
/////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes > dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
int res = 0;
if (src.isSpecialActual()) {
res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream());
} else if (src.isPrimaryActual()) {
res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream());
}
if (res != 0)
throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res);
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
if (res != 0)
throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res);
dst.writeSpecial();
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() {
memory::Workspace* newWorkspace = nullptr;
void* newBuffer;
ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
if (res != 0)
throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res);
if (_isOwnerSpecial) {
// now we're releasing original buffer
@ -203,7 +261,7 @@ void DataBuffer::migrate() {
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { _writePrimary = ++_counter; }
void DataBuffer::writePrimary() const {_writePrimary = ++_counter; }
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }
void DataBuffer::readPrimary() const { _readPrimary = ++_counter; }
void DataBuffer::readSpecial() const { _readSpecial = ++_counter; }

View File

@ -0,0 +1,301 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <array/DataBuffer.h>
#include <helpers/logger.h>
#include <array/DataTypeUtils.h>
#include <execution/AffinityManager.h>
namespace nd4j {
///// IMPLEMENTATION OF COMMON METHODS /////
////////////////////////////////////////////////////////////////////////
// default constructor
DataBuffer::DataBuffer() {
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = 0;
_dataType = INT8;
_workspace = nullptr;
_isOwnerPrimary = false;
_isOwnerSpecial = false;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
}
////////////////////////////////////////////////////////////////////////
// copy constructor
DataBuffer::DataBuffer(const DataBuffer &other) {
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_deviceId.store(other._deviceId.load());
setCountersToZero();
allocateBuffers();
copyBufferFrom(other);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary, const bool isOwnerSpecial,
memory::Workspace* workspace) {
if (primary == nullptr && special == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
_primaryBuffer = primary;
_specialBuffer = special;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
if(primary != nullptr)
readPrimary();
if(special != nullptr)
readSpecial();
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
syncToSpecial(true);
}
////////////////////////////////////////////////////////////////////////
// copies data from hostBuffer to own memory buffer
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
if (hostBuffer == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
if (lenInBytes == 0)
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
allocateBuffers();
copyBufferFromHost(hostBuffer, lenInBytes);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
_dataType = dataType;
_workspace = workspace;
_lenInBytes = lenInBytes;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_deviceId = nd4j::AffinityManager::currentDeviceId();
setCountersToZero();
if(lenInBytes != 0) {
allocateBuffers(allocBoth);
writeSpecial();
}
}
////////////////////////////////////////////////////////////////////////
// move constructor
DataBuffer::DataBuffer(DataBuffer&& other) {
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
_deviceId.store(other._deviceId);
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
if (this == &other)
return *this;
deleteBuffers();
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
if (this == &other)
return *this;
deleteBuffers();
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::primary() {
return _primaryBuffer;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::special() {
return _specialBuffer;
}
////////////////////////////////////////////////////////////////////////
DataType DataBuffer::getDataType() {
return _dataType;
}
////////////////////////////////////////////////////////////////////////
size_t DataBuffer::getLenInBytes() const {
return _lenInBytes;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocatePrimary() {
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
_isOwnerPrimary = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deletePrimary() {
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deleteBuffers() {
deletePrimary();
deleteSpecial();
_lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
DataBuffer::~DataBuffer() {
deleteBuffers();
}
void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) {
if (_primaryBuffer != nullptr && _isOwnerPrimary) {
deletePrimary();
}
_primaryBuffer = buffer;
_isOwnerPrimary = false;
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
this->setSpecial(buffer, false);
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
void DataBuffer::setDataType(DataType dataType) {
_dataType = dataType;
}
int DataBuffer::deviceId() const {
return _deviceId.load();
}
void DataBuffer::close() {
this->deleteBuffers();
}
void DataBuffer::setDeviceId(int deviceId) {
_deviceId = deviceId;
}
}

View File

@ -0,0 +1,146 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <array/InteropDataBuffer.h>
#include <array/DataTypeUtils.h>
#include <execution/AffinityManager.h>
#include <helpers/logger.h>
namespace nd4j {
InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) {
_dataBuffer = dataBuffer.getDataBuffer();
// offset is always absolute, i.e. relative to the start of the original buffer rather than to the view passed in
_offset = offset;
if (_offset + length > _dataBuffer->getLenInBytes()) {
throw std::runtime_error("offset + length is higher than original length");
}
}
InteropDataBuffer::InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer) {
_dataBuffer = databuffer;
}
InteropDataBuffer::InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth) {
if (elements == 0) {
_dataBuffer = std::make_shared<DataBuffer>();
_dataBuffer->setDataType(dtype);
} else {
_dataBuffer = std::make_shared<DataBuffer>(elements, dtype, nullptr, allocateBoth);
}
}
std::shared_ptr<DataBuffer> InteropDataBuffer::getDataBuffer() const {
return _dataBuffer;
}
std::shared_ptr<DataBuffer> InteropDataBuffer::dataBuffer() {
return _dataBuffer;
}
void* InteropDataBuffer::primary() const {
return reinterpret_cast<int8_t *>(_dataBuffer->primary()) + _offset;
}
void* InteropDataBuffer::special() const {
return reinterpret_cast<int8_t *>(_dataBuffer->special()) + _offset;
}
void InteropDataBuffer::setPrimary(void* ptr, size_t length) {
_dataBuffer->setPrimaryBuffer(ptr, length);
}
void InteropDataBuffer::setSpecial(void* ptr, size_t length) {
_dataBuffer->setSpecialBuffer(ptr, length);
}
uint64_t InteropDataBuffer::offset() const {
return _offset;
}
void InteropDataBuffer::setOffset(uint64_t offset) {
_offset = offset;
}
int InteropDataBuffer::deviceId() const {
return _dataBuffer->deviceId();
}
void InteropDataBuffer::registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
for (const auto &v:writeList) {
if (v == nullptr)
continue;
v->getDataBuffer()->writeSpecial();
}
}
void InteropDataBuffer::prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
auto currentDeviceId = nd4j::AffinityManager::currentDeviceId();
for (const auto &v:readList) {
if (v == nullptr)
continue;
if (v->getDataBuffer()->deviceId() != currentDeviceId)
v->getDataBuffer()->migrate();
v->getDataBuffer()->syncToSpecial();
}
// we don't tick write list, only ensure the same device affinity
for (const auto &v:writeList) {
if (v == nullptr)
continue;
// special case for legacy ops - views can be updated on the host side, so the original array may be out of date
if (!v->getDataBuffer()->isSpecialActual())
v->getDataBuffer()->syncToSpecial();
if (v->getDataBuffer()->deviceId() != currentDeviceId)
v->getDataBuffer()->migrate();
}
}
void InteropDataBuffer::registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
for (const auto &v:writeList) {
if (v == nullptr)
continue;
}
}
void InteropDataBuffer::preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
for (const auto &v:readList) {
if (v == nullptr)
continue;
v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext());
}
}
void InteropDataBuffer::expand(size_t newlength) {
_dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType()));
}
void InteropDataBuffer::setDeviceId(int deviceId) {
_dataBuffer->setDeviceId(deviceId);
}
}
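
For illustration only (not code from this changeset), a sketch of the intended prepare/register pattern around a device kernel, assuming the static helpers defined above: the read list is synced (and migrated if needed) to the current device up front, and the write list is ticked afterwards so the device copy counts as the most recent one.

#include <array/InteropDataBuffer.h>

// hypothetical sketch of the prepare/register pattern
void launchOnDeviceSketch(const nd4j::InteropDataBuffer *input, const nd4j::InteropDataBuffer *output) {
    // inputs must be actual on the device before the kernel runs
    nd4j::InteropDataBuffer::prepareSpecialUse({output}, {input});

    // ... enqueue the kernel that reads input->special() and writes output->special() ...

    // mark the outputs as written on the device side
    nd4j::InteropDataBuffer::registerSpecialUse({output}, {input});
}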

View File

@ -138,7 +138,7 @@ namespace nd4j {
if (res != 0)
throw cuda_exception::build("_reductionPointer allocation failed", res);
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
res = cudaHostAlloc(reinterpret_cast<void**>(&_scalarPointer), 16, cudaHostAllocDefault);
if (res != 0)
throw cuda_exception::build("_scalarPointer allocation failed", res);

View File

@ -185,9 +185,11 @@ namespace nd4j {
void setInputArray(int index, NDArray *array, bool removable = false);
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setOutputArray(int index, NDArray *array, bool removable = false);
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
void setTArguments(double *arguments, int numberOfArguments);
void setIArguments(Nd4jLong *arguments, int numberOfArguments);

View File

@ -21,6 +21,7 @@
#include <Context.h>
#include <helpers/ShapeUtils.h>
#include <graph/Context.h>
#include <array/InteropDataBuffer.h>
namespace nd4j {
@ -426,6 +427,44 @@ namespace nd4j {
array->setContext(_context);
}
void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_in.size() < index + 1)
_fastpath_in.resize(index+1);
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
_fastpath_in[index] = array;
_handles.emplace_back(array);
if (_context != nullptr)
array->setContext(_context);
}
void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
if (_fastpath_out.size() < index + 1)
_fastpath_out.resize(index+1);
NDArray *array;
if (dataBuffer != nullptr)
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
else
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
_fastpath_out[index] = array;
_handles.emplace_back(array);
if (_context != nullptr)
array->setContext(_context);
}
void Context::setTArguments(double *arguments, int numberOfArguments) {
_tArgs.clear();
_tArgs.reserve(numberOfArguments);

View File

@ -43,6 +43,8 @@ enum DType:byte {
QINT16,
BFLOAT16 = 17,
UTF8 = 50,
UTF16 = 51,
UTF32 = 52,
}
// this structure describe NDArray

View File

@ -34,8 +34,6 @@
#include <driver_types.h>
#include <cuda_runtime_api.h>
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
#endif
#include <DebugInfo.h>
namespace nd4j {

View File

@ -25,6 +25,8 @@
#include <op_boilerplate.h>
#include <string>
#include <sstream>
#include <vector>
#include <NDArray.h>
namespace nd4j {
class ND4J_EXPORT StringUtils {
@ -53,6 +55,36 @@ namespace nd4j {
return result;
}
/**
* This method returns the number of needle matches within the haystack
* PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8
*
* @param haystack
* @param haystackLength
* @param needle
* @param needleLength
* @return
*/
static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength);
/**
* This method returns the number of bytes used for a string NDArray's content
* PLEASE NOTE: this doesn't include header
*
* @param array
* @return
*/
static uint64_t byteLength(const NDArray &array);
/**
* This method splits a string into substrings by a delimiter
*
* @param haystack
* @param delimiter
* @return
*/
static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
};
}
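
For illustration only (not code from this changeset), a small sketch of the expected behaviour of the helpers declared above: countSubarrays counts byte-level needle matches, and split returns the pieces between delimiter occurrences, including empty ones.

#include <helpers/StringUtils.h>
#include <cassert>
#include <string>

// hypothetical sketch of the StringUtils helpers
void stringUtilsSketch() {
    std::string haystack = "a,b,,c";
    std::string needle = ",";

    // three delimiter occurrences -> four substrings after splitting
    auto matches = nd4j::StringUtils::countSubarrays(haystack.c_str(), haystack.length(), needle.c_str(), needle.length());
    auto parts = nd4j::StringUtils::split(haystack, needle);
    assert(matches == 3 && parts.size() == 4 && parts[2].empty());
}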

View File

@ -19,7 +19,58 @@
//
#include <helpers/StringUtils.h>
#include <exceptions/datatype_exception.h>
namespace nd4j {
static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) {
for (int e = 0; e < length; e++)
if (haystack[e] != needle[e])
return false;
return true;
}
uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) {
auto haystack = reinterpret_cast<const uint8_t*>(vhaystack);
auto needle = reinterpret_cast<const uint8_t*>(vneedle);
uint64_t number = 0;
for (uint64_t e = 0; e < haystackLength - needleLength; e++) {
if (match(&haystack[e], needle, needleLength))
number++;
}
return number;
}
uint64_t StringUtils::byteLength(const NDArray &array) {
if (!array.isS())
throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());
uint64_t result = 0;
// our buffer stores offsets, and the last value is the total number of bytes used
auto buffer = array.bufferAsT<Nd4jLong>();
result = buffer[array.lengthOf()];
return result;
}
std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
std::vector<std::string> output;
std::string::size_type prev_pos = 0, pos = 0;
// iterating through the haystack till the end
while((pos = haystack.find(delimiter, pos)) != std::string::npos) {
output.emplace_back(haystack.substr(prev_pos, pos-prev_pos));
prev_pos = ++pos;
}
output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word
return output;
}
}

View File

@ -20,7 +20,6 @@
//
#include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h>
#include <loops/reduce_bool.h>
#include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
//
#include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h>
#include <loops/reduce_float.h>
#include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
//
#include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h>
#include <loops/reduce_long.h>
#include <loops/legacy_ops.h>

View File

@ -20,7 +20,6 @@
//
#include <types/types.h>
#include <ShapeUtils.h>
#include <op_boilerplate.h>
#include <loops/reduce_same.h>
#include <loops/legacy_ops.h>

View File

@ -1624,4 +1624,9 @@
#define PARAMETRIC_D() [&] (Parameters &p) -> Context*
#ifdef __CUDABLAS__
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
#endif
#endif

View File

@ -40,6 +40,9 @@
#include <ops/declarable/headers/third_party.h>
#include <ops/declarable/headers/tests.h>
#include <ops/declarable/headers/kernels.h>
#include <ops/declarable/headers/strings.h>
#include <ops/declarable/headers/compat.h>
#include <ops/declarable/headers/util.h>
#include <ops/declarable/headers/BarnesHutTsne.h>
#include <ops/declarable/headers/images.h>
#include <dll.h>

View File

@ -0,0 +1 @@
This folder contains operations required for compatibility with TF and other frameworks.

View File

@ -0,0 +1,73 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/sparse_to_dense.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) {
auto indices = INPUT_VARIABLE(0);
auto shape = INPUT_VARIABLE(1);
auto values = INPUT_VARIABLE(2);
NDArray *def = nullptr;
auto output = OUTPUT_VARIABLE(0);
if (block.width() > 3)
def = INPUT_VARIABLE(3);
nd4j::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output);
return Status::OK();
};
DECLARE_SHAPE_FN(compat_sparse_to_dense) {
auto indices = INPUT_VARIABLE(0);
auto shape = INPUT_VARIABLE(1);
auto values = INPUT_VARIABLE(2);
if (block.width() > 3) {
auto def = INPUT_VARIABLE(3);
REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values")
};
auto dtype = values->dataType();
// the output shape is defined by the values data type and the desired shape input
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector<Nd4jLong>()));
}
DECLARE_TYPES(compat_sparse_to_dense) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS}) // indices
->setAllowedInputTypes(1, {ALL_INTS}) // shape
->setAllowedInputTypes(2,nd4j::DataType::ANY) // sparse values
->setAllowedInputTypes(3,nd4j::DataType::ANY) // default value
->setAllowedOutputTypes(nd4j::DataType::ANY);
}
}
}
#endif
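
To make the semantics above concrete: with indices [[0,0],[1,2]], shape [2,3], values [5,7] and a default of 0, the op produces the dense array [[5,0,0],[0,0,7]]. Below is a plain, self-contained C++ sketch of that numeric fill (no nd4j API involved), mirroring the "indices come in blocks of rank" layout used by the helper.

#include <cstdint>
#include <cstdio>
#include <vector>

// illustrative sketch of the sparse-to-dense fill over a row-major 2x3 output
int main() {
    const int rank = 2;
    std::vector<int64_t> shape   = {2, 3};
    std::vector<int64_t> indices = {0, 0,   1, 2};   // [[0,0],[1,2]]
    std::vector<float>   values  = {5.f, 7.f};
    float defaultValue = 0.f;

    std::vector<float> dense(shape[0] * shape[1], defaultValue);
    for (size_t e = 0; e < values.size(); e++) {
        // row-major offset of the e-th coordinate block
        auto offset = indices[e * rank] * shape[1] + indices[e * rank + 1];
        dense[offset] = values[e];
    }

    // prints: 5 0 0 0 0 7
    for (auto v : dense)
        printf("%.0f ", v);
    printf("\n");
    return 0;
}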

View File

@ -0,0 +1,140 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
#include <helpers/StringUtils.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
auto indices = OUTPUT_VARIABLE(0);
auto values = OUTPUT_VARIABLE(1);
auto d = delim->e<std::string>(0);
input->syncToHost();
delim->syncToHost();
// output rank N+1 wrt input rank
std::vector<Nd4jLong> ocoords(input->rankOf() + 1);
std::vector<Nd4jLong> icoords(input->rankOf());
// getting buffer lengths
// FIXME: this overestimates the required size, since it still includes the delimiters
auto outputLength = StringUtils::byteLength(*input);
uint64_t ss = 0L;
Nd4jLong ic = 0L;
// loop through each string within tensor
for (auto e = 0L; e < input->lengthOf(); e++) {
// now we should map substring to indices
auto s = input->e<std::string>(e);
// getting base index
shape::index2coords(e, input->shapeInfo(), icoords.data());
// getting number of substrings
auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
// filling output indices
for (uint64_t f = 0; f < cnt; f++) {
for (auto v: icoords)
indices->p(ic++, v);
// last index
indices->p(ic++, f);
}
ss += cnt;
}
// process strings now
std::vector<std::string> strings;
for (auto e = 0L; e < input->lengthOf(); e++) {
auto split = StringUtils::split(input->e<std::string>(e), d);
for (const auto &s:split)
strings.emplace_back(s);
}
// now once we have all strings in single vector time to fill
auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly
values->dataBuffer()->allocatePrimary();
values->dataBuffer()->expand(blen);
memcpy(values->buffer(), tmp.buffer(), blen);
values->tickWriteHost();
// special case, for future use
indices->syncToDevice();
values->syncToDevice();
// we have to tick buffers
values->dataBuffer()->writePrimary();
values->dataBuffer()->readSpecial();
return Status::OK();
};
DECLARE_SHAPE_FN(compat_string_split) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
auto d = delim->e<std::string>(0);
// count number of delimiter substrings in all strings within input tensor
uint64_t cnt = 0;
for (auto e = 0L; e < input->lengthOf(); e++) {
// FIXME: bad, not UTF-compatible
auto s = input->e<std::string>(e);
// each delimiter occurrence splits the string into two parts, so we add 1 to the number of matches
cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
}
// shape calculations
// virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension
// values tensor is going to be vector always
// indices tensor is going to be vector with length equal to values.length * output rank
auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8);
auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64);
return SHAPELIST(indicesShape, valuesShape);
}
DECLARE_TYPES(compat_string_split) {
getOpDescriptor()
->setAllowedInputTypes({ALL_STRINGS})
->setAllowedOutputTypes(0, {ALL_INDICES})
->setAllowedOutputTypes(1, {ALL_STRINGS});
}
}
}
#endif
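
To make the output layout above concrete: for a rank-1 input of two strings and a single-character delimiter, the values output is a flat vector of all substrings, and the indices output stores, for each substring, the coordinates of the source string followed by the substring's position within it. Below is a plain, self-contained C++ sketch of that index layout (no nd4j API involved).

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// illustrative sketch of the compat_string_split index/value layout for a rank-1 input
static std::vector<std::string> splitOne(const std::string &s, const std::string &d) {
    std::vector<std::string> out;
    std::string::size_type prev = 0, pos = 0;
    while ((pos = s.find(d, pos)) != std::string::npos) {
        out.emplace_back(s.substr(prev, pos - prev));
        prev = ++pos;
    }
    out.emplace_back(s.substr(prev));
    return out;
}

int main() {
    std::vector<std::string> input = {"alpha beta", "gamma"};
    std::string delim = " ";

    std::vector<int64_t> indices;      // (source string index, substring index) pairs
    std::vector<std::string> values;   // flat list of substrings
    for (int64_t e = 0; e < (int64_t) input.size(); e++) {
        auto parts = splitOne(input[e], delim);
        for (int64_t f = 0; f < (int64_t) parts.size(); f++) {
            indices.push_back(e);
            indices.push_back(f);
            values.emplace_back(parts[f]);
        }
    }

    // prints: [0, 0] -> alpha, [0, 1] -> beta, [1, 0] -> gamma
    for (size_t i = 0; i < indices.size(); i += 2)
        printf("[%lld, %lld] -> %s\n", (long long) indices[i], (long long) indices[i + 1], values[i / 2].c_str());
    return 0;
}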

View File

@ -47,8 +47,7 @@ namespace nd4j {
}
// just memcpy data
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer());
return Status::OK();
}

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
return Status::OK();
};
DECLARE_SHAPE_FN(split_string) {
auto input = INPUT_VARIABLE(0);
auto delim = INPUT_VARIABLE(1);
return SHAPELIST();
}
DECLARE_TYPES(split_string) {
getOpDescriptor()
->setAllowedInputTypes({ALL_STRINGS})
->setAllowedOutputTypes({ALL_STRINGS});
}
}
}
#endif

View File

@ -0,0 +1,52 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_print_affinity)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) {
// TODO: make this op compatible with ArrayList etc
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
nd4j_printf("<Node %i>: Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->getBuffer(), input->getSpecialBuffer(), input->dataBuffer()->getLenInBytes());
return Status::OK();
}
DECLARE_TYPES(print_affinity) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_STRINGS})
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(print_affinity) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
}
}
}
#endif

View File

@ -0,0 +1,77 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_print_variable)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) {
// TODO: make this op compatible with ArrayList etc
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
std::string str;
if (block.width() == 2) {
auto message = INPUT_VARIABLE(1);
REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String");
str = message->e<std::string>(0);
}
bool printSpecial = false;
if (block.numB() > 0)
printSpecial = B_ARG(0);
if (printSpecial && !nd4j::Environment::getInstance()->isCPU()) {
// only specific backends support special printout. for cpu-based backends it's the same as regular print
if (block.width() == 2)
helpers::print_special(*block.launchContext(), *input, str);
else
helpers::print_special(*block.launchContext(), *input);
} else {
// optionally add message to the print out
if (block.width() == 2) {
input->printIndexedBuffer(str.c_str());
} else {
input->printIndexedBuffer();
}
}
return Status::OK();
}
DECLARE_TYPES(print_variable) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {ALL_STRINGS})
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(print_variable) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
}
}
}
#endif

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_COMPAT_H
#define SAMEDIFF_COMPAT_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation splits input string into pieces separated by delimiter
* PLEASE NOTE: This implementation is compatible with TF 1.x
*
* Input[0] - string to split
* Input[1] - delimiter
*
* Returns:
* Output[0] - indices tensor
* Output[1] - values tensor
*/
#if NOT_EXCLUDED(OP_compat_string_split)
DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0);
#endif
/**
* This operation converts TF sparse array representation to dense NDArray
*/
#if NOT_EXCLUDED(OP_compat_sparse_to_dense)
DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0);
#endif
}
}
#endif //SAMEDIFF_COMPAT_H

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_STRINGS_H
#define SAMEDIFF_STRINGS_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation splits input string into pieces separated by delimiter
*
* Input[0] - string to split
* Input[1] - delimiter
*/
#if NOT_EXCLUDED(OP_split_string)
DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0);
#endif
}
}
#endif //SAMEDIFF_STRINGS_H

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_UTILS_H
#define LIBND4J_UTILS_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
/**
* This operation prints out NDArray content, either on host or device.
*/
#if NOT_EXCLUDED(OP_print_variable)
DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0);
#endif
/**
* This operation prints out affinity & locality status of given NDArray
*/
#if NOT_EXCLUDED(OP_print_affinity)
DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0);
#endif
}
}
#endif //LIBND4J_UTILS_H

View File

@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/print_variable.h>
namespace nd4j {
namespace ops {
namespace helpers {
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
array.printIndexedBuffer(message.c_str());
}
}
}
}

View File

@ -40,15 +40,11 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
const auto y = reinterpret_cast<const Y*>(vy);
auto z = reinterpret_cast<X*>(vz);
__shared__ Nd4jLong xzLen, totalThreads, *sharedMem;
__shared__ Nd4jLong xzLen;
__shared__ int xzRank, yRank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xzLen = shape::length(xShapeInfo);
totalThreads = gridDim.x * blockDim.x;
xzRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
@ -56,18 +52,15 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong* coords = sharedMem + threadIdx.x * xzRank;
for (int i = tid; i < xzLen; i += totalThreads) {
Nd4jLong coords[MAX_RANK];
for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) {
shape::index2coords(i, xShapeInfo, coords);
const auto xzOffset = shape::getOffset(xShapeInfo, coords);
const auto xVal = x[xzOffset];
if(xVal < 0) {
for (uint j = 0; j < yRank; ++j)
if(yShapeInfo[j + 1] == 1)
coords[j + 1] = 0;
@ -82,7 +75,6 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
///////////////////////////////////////////////////////////////////
template<typename X, typename Y>
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
}
@ -91,9 +83,9 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a
PointersManager manager(context, "prelu");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const int threadsPerBlock = 256;
const int blocksPerGrid = 512;
const int sharedMem = 512;
const auto xType = input.dataType();
const auto yType = alpha.dataType();
@ -119,13 +111,10 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
auto dLdI = reinterpret_cast<Y*>(vdLdI);
auto dLdA = reinterpret_cast<Y*>(vdLdA);
__shared__ Nd4jLong inLen, totalThreads, *sharedMem;
__shared__ Nd4jLong inLen, totalThreads;
__shared__ int inRank, alphaRank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
inLen = shape::length(inShapeInfo);
totalThreads = gridDim.x * blockDim.x;
@ -135,10 +124,9 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong* coords = sharedMem + threadIdx.x * inRank;
Nd4jLong coords[MAX_RANK];
for (int i = tid; i < inLen; i += totalThreads) {
shape::index2coords(i, inShapeInfo, coords);
const auto inOffset = shape::getOffset(inShapeInfo, coords);
@ -175,14 +163,13 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr
//////////////////////////////////////////////////////////////////////////
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
dLdA.nullify();
PointersManager manager(context, "preluBP");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const int threadsPerBlock = 256;
const int blocksPerGrid = 512;
const int sharedMem = 512;
const auto xType = input.dataType();
const auto zType = alpha.dataType();

View File

@ -0,0 +1,61 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/print_variable.h>
#include <helpers/PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) {
auto length = shape::length(shapeInfo);
auto x = reinterpret_cast<const T*>(special);
// TODO: add formatting here
printf("[");
for (uint64_t e = 0; e < length; e++) {
printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]);
if (e < length - 1)
printf(", ");
}
printf("]\n");
}
template <typename T>
static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) {
print_device<T><<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo);
}
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
NDArray::prepareSpecialUse({}, {&array});
PointersManager pm(&ctx, "print_device");
BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.getSpecialBuffer(), array.getSpecialShapeInfo()), LIBND4J_TYPES)
pm.synchronize();
NDArray::registerSpecialUse({}, {&array});
}
}
}
}

View File

@ -41,6 +41,9 @@
#include <helpers/DebugHelper.h>
#include <stdio.h>
#include <stdlib.h>
#include <DebugHelper.h>
#endif // CUDACC
#endif // LIBND4J_HELPERS_H

View File

@ -0,0 +1,123 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/sparse_to_dense.h>
#include <helpers/StringUtils.h>
#include <helpers/ShapeUtils.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename I>
static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) {
auto values = reinterpret_cast<const X*>(vvalues);
auto indices = reinterpret_cast<const I*>(vindices);
auto output = reinterpret_cast<X*>(voutput);
Nd4jLong coords[MAX_RANK];
uint64_t pos = 0;
for (uint64_t e = 0L; e < length; e++) {
// indices come in blocks
for (uint8_t p = 0; p < rank; p++) {
coords[p] = indices[pos++];
}
// fill output at given coords with sparse value
output[shape::getOffset(zShapeInfo, coords)] = values[e];
}
}
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) {
// make sure host buffer is updated
values.syncToHost();
indices.syncToHost();
auto rank = output.rankOf();
if (output.isS()) {
// string case is not so trivial, since elements might, and probably will, have different sizes
auto numValues = values.lengthOf();
auto numElements = output.lengthOf();
// first of all we calculate final buffer sizes and offsets
auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def);
auto valuesLength = StringUtils::byteLength(values);
auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength;
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements);
// now we make sure our output buffer can hold results
output.dataBuffer()->expand( bufferLength + headerLength);
std::vector<Nd4jLong> outputCoords(rank);
std::vector<Nd4jLong> valueCoords(rank);
auto offsetsBuffer = output.bufferAsT<Nd4jLong>();
auto dataBuffer = reinterpret_cast<uint8_t*>(offsetsBuffer + output.lengthOf());
offsetsBuffer[0] = 0;
// getting initial value coords
for (int e = 0; e < rank; e++)
valueCoords[e] = indices.e<Nd4jLong>(e);
// write results individually
for (uint64_t e = 0; e < numElements; e++) {
auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data());
auto cLength = 0L;
std::string str;
if (vIndex == e) {
// we're writing down sparse value here
str = values.e<std::string>(e);
} else {
// we're writing down default value if it exists
if (def != nullptr)
str = def->e<std::string>(0);
else
str = "";
}
// TODO: make it unicode compliant
memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length());
// writing down the next offset
offsetsBuffer[e+1] = offsetsBuffer[e] + str.length();
}
} else {
// numeric case is trivial, since all elements have equal sizes
// write out default values, if they are present
if (def != nullptr) {
output.assign(def);
// make sure output is synced back
output.syncToHost();
}
// write out values
BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.getBuffer(), indices.getBuffer(), output.buffer(), output.getShapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES);
}
// copy back to device, if there's any
output.syncToDevice();
}
}
}
}
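
The string branch above relies on the buffer layout used for string arrays: a header of element offsets followed by the concatenated UTF-8 bytes, with the final offset equal to the total payload size (this is what StringUtils::byteLength reads). For illustration only, a plain, self-contained C++ sketch of that idea; the exact header size in libnd4j comes from ShapeUtils::stringBufferHeaderRequirements, which is not shown here.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

// illustrative sketch of an offsets-header layout for a string array
int main() {
    std::vector<std::string> strings = {"ab", "cde", ""};

    // header: one offset per element plus a trailing total, then the concatenated bytes
    std::vector<int64_t> offsets(strings.size() + 1, 0);
    for (size_t e = 0; e < strings.size(); e++)
        offsets[e + 1] = offsets[e] + (int64_t) strings[e].length();

    std::vector<uint8_t> payload(offsets.back());
    for (size_t e = 0; e < strings.size(); e++)
        std::memcpy(payload.data() + offsets[e], strings[e].data(), strings[e].length());

    // prints: offsets = 0 2 5 5 ; byteLength = 5
    printf("offsets = ");
    for (auto o : offsets)
        printf("%lld ", (long long) o);
    printf("; byteLength = %lld\n", (long long) offsets.back());
    return 0;
}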

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_PRINT_VARIABLE_H
#define LIBND4J_PRINT_VARIABLE_H
#include <ops/declarable/helpers/helpers.h>
namespace nd4j {
namespace ops {
namespace helpers {
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {});
}
}
}
#endif //LIBND4J_PRINT_VARIABLE_H
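A minimal usage sketch for the declaration above; the include paths, the use of the default LaunchContext, and the sample array are assumptions made for illustration only:

#include <ops/declarable/helpers/print_variable.h>
#include <NDArrayFactory.h>

static void printDemo() {
    auto array = nd4j::NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    // prints the array contents, prefixed by the optional message
    nd4j::ops::helpers::print_special(*nd4j::LaunchContext::defaultContext(), array, "demo array: ");
}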

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_SPARSE_TO_DENSE_H
#define SAMEDIFF_SPARSE_TO_DENSE_H
#include <ops/declarable/helpers/helpers.h>
namespace nd4j {
namespace ops {
namespace helpers {
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output);
}
}
}
#endif //SAMEDIFF_SPARSE_TO_DENSE_H

View File

@ -634,7 +634,7 @@
#define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
#define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
#define ALL_STRINGS nd4j::DataType::UTF8, nd4j::DataType::UTF16, nd4j::DataType::UTF32
#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64
#define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64
#define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16

View File

@ -810,9 +810,10 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
#ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream());
#endif
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
@ -844,8 +845,10 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
#ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream());
#endif
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.specialShapeInfo(),
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());

View File

@ -0,0 +1,94 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests17 : public testing::Test {
public:
DeclarableOpsTests17() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
auto values = NDArrayFactory::create<float>({1.f, 2.f, 3.f});
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
auto def = NDArrayFactory::create<float>(0.f);
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f});
nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
auto def = NDArrayFactory::string("d");
auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
nd4j::ops::compat_sparse_to_dense op;
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
auto delimiter = NDArrayFactory::string(" ");
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
nd4j::ops::compat_string_split op;
auto result = op.execute({&x, &delimiter}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(2, result->size());
auto z0 = result->at(0);
auto z1 = result->at(1);
ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp1.isSameShape(z1));
ASSERT_EQ(exp0, *z0);
ASSERT_EQ(exp1, *z1);
delete result;
}
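The two sparse_to_dense tests above only assert the execution status. A hedged follow-up check that could be appended before delete result; in each of them, assuming result->at(0) holds the dense output (the comparison mirrors the assertions used in the string_split test):

auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_EQ(exp, *z);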

View File

@ -0,0 +1,52 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests18 : public testing::Test {
public:
DeclarableOpsTests18() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests18, test_bitcast_1) {
auto x = NDArrayFactory::create<double>(0.23028551377579154);
auto z = NDArrayFactory::create<Nd4jLong>(0);
auto e = NDArrayFactory::create<Nd4jLong>(4597464930322771456L);
nd4j::ops::bitcast op;
auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
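For context on what this test asserts: bitcast reinterprets the input's raw bytes as the requested type, with no numeric conversion. A standalone C++ sketch of the same byte-wise reinterpretation, independent of the op's actual implementation:

#include <cstdint>
#include <cstring>

static int64_t bitcastDoubleToInt64(double value) {
    static_assert(sizeof(int64_t) == sizeof(double), "byte-wise reinterpret requires equal widths");
    int64_t out;
    std::memcpy(&out, &value, sizeof(out)); // copy the 8 raw bytes as-is
    return out;
}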

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests19 : public testing::Test {
public:
DeclarableOpsTests19() {
printf("\n");
fflush(stdout);
}
};

View File

@ -834,12 +834,17 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dims.dataBuffer());
execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(),
packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
@ -981,10 +986,14 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
OpaqueDataBuffer xBuf(arrayX.dataBuffer());
OpaqueDataBuffer yBuf(arrayY.dataBuffer());
OpaqueDataBuffer zBuf(arrayZ.dataBuffer());
execPairwiseTransform(nullptr, pairwise::Add,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
&xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(),
&yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(),
&zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(),
nullptr);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
@ -1220,10 +1229,10 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) {
auto z = NDArrayFactory::create<bfloat16>('c', {10});
RandomGenerator rng(119, 323841120L);
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args);
OpaqueDataBuffer zBuf(z.dataBuffer());
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args);
//z.printIndexedBuffer("z");
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
}
@ -1267,6 +1276,64 @@ TEST_F(JavaInteropTests, test_size_dtype_1) {
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
auto d = NDArrayFactory::string(" ");
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""});
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
InteropDataBuffer iz0(z0.dataBuffer());
InteropDataBuffer iz1(z1.dataBuffer());
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo());
ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo());
nd4j::ops::compat_string_split op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp0, z0);
ASSERT_EQ(exp1, z1);
}
TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3, 3, 3});
auto z = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
double buffer[2048];
InteropDataBuffer ix(0, DataType::DOUBLE, false);
InteropDataBuffer iy(0, DataType::DOUBLE, false);
InteropDataBuffer iz(0, DataType::DOUBLE, false);
// we're imitating workspace-managed array here
ix.setPrimary(buffer + 64, x.lengthOf());
iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf());
iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf());
Context ctx(1);
ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo());
ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo());
ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo());
ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});
nd4j::ops::maxpool2d_bp op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}
/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -470,12 +470,16 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -506,14 +510,17 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
ASSERT_EQ(e, z);
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -543,14 +550,17 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
// z.printIndexedBuffer("z");
@ -583,13 +593,16 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3Tad(extraPointers, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
@ -615,10 +628,15 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
NDArray::prepareSpecialUse({&z}, {&x, &y});
execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer dimBuf(dim.dataBuffer());
execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
@ -730,13 +748,16 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) {
auto z = NDArrayFactory::create<float>('c', {0, 2});
auto e = NDArrayFactory::create<float>('c', {0, 2});
InteropDataBuffer xdb(x.dataBuffer());
InteropDataBuffer ddb(d.dataBuffer());
InteropDataBuffer zdb(z.dataBuffer());
::execReduceSame2(nullptr, reduce::SameOps::Sum,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
&xdb, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
&zdb, z.shapeInfo(), z.specialShapeInfo(),
&ddb, d.shapeInfo(), d.specialShapeInfo());
}
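The common thread in these test updates is the new native call convention: each NDArray's DataBuffer is wrapped once in an OpaqueDataBuffer, and that wrapper replaces the former host/device buffer pointer pair while the shape-info arguments stay in place. A minimal sketch of the pattern, assuming x, z and the extra pointers are already prepared as in the tests above (execTransformSame is just one example of the many exec* entry points changed this way):

// old style passed four pointers per array:
//   x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()
// new style wraps the DataBuffer once and passes it together with the two shape infos:
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execTransformSame(extra, transform::Square,
                    &xBuf, x.shapeInfo(), x.specialShapeInfo(),
                    &zBuf, z.shapeInfo(), z.specialShapeInfo(),
                    nullptr);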

View File

@ -119,13 +119,15 @@ TEST_F(NativeOpsTests, ExecIndexReduce_1) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execIndexReduceScalar(nullptr,
indexreduce::IndexMax,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr);
nullptr,
&expBuf, exp.shapeInfo(),
nullptr);
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 4LL);
#endif
@ -140,15 +142,18 @@ TEST_F(NativeOpsTests, ExecIndexReduce_2) {
printf("Unsupported for cuda now.\n");
#else
NDArray dimension = NDArrayFactory::create<int>({});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimensionBuf(dimension.dataBuffer());
::execIndexReduce(nullptr,
indexreduce::IndexMax,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr);
&expBuf, exp.shapeInfo(),
nullptr,
&dimensionBuf, dimension.shapeInfo(),
nullptr);
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 24LL);
#endif
@ -166,16 +171,21 @@ TEST_F(NativeOpsTests, ExecBroadcast_1) {
#else
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execBroadcast(nullptr,
broadcast::Add,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr);
&xBuf, x.shapeInfo(),
nullptr,
&yBuf, y.shapeInfo(),
nullptr,
&expBuf, exp.shapeInfo(),
nullptr,
&dimBuf, dimension.shapeInfo(),
nullptr);
ASSERT_TRUE(exp.e<float>(0) == 3.);
#endif
@ -194,17 +204,18 @@ printf("Unsupported for cuda now.\n");
int dimd = 0;
auto dimension = NDArrayFactory::create<int>('c', {1}, {dimd});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execBroadcastBool(nullptr,
broadcast::EqualTo,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
nullptr,
dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr);
&xBuf, x.shapeInfo(), nullptr,
&yBuf, y.shapeInfo(), nullptr,
&expBuf, exp.shapeInfo(), nullptr, nullptr,
&dimBuf, dimension.shapeInfo(),
nullptr);
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
#endif
@ -219,14 +230,15 @@ TEST_F(NativeOpsTests, ExecPairwise_1) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execPairwiseTransform(nullptr,
pairwise::Add,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
&yBuf, y.shapeInfo(), nullptr,
&expBuf, exp.shapeInfo(), nullptr,
nullptr);
ASSERT_TRUE(exp.e<float>(5) == 8.);
#endif
@ -243,14 +255,15 @@ TEST_F(NativeOpsTests, ExecPairwise_2) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execPairwiseTransformBool(nullptr,
pairwise::And,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(),
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
&yBuf, y.shapeInfo(), nullptr,
&expBuf, exp.shapeInfo(), nullptr,
nullptr);
ASSERT_TRUE(exp.e<bool>(5) && !exp.e<bool>(4));
#endif
@ -266,14 +279,14 @@ TEST_F(NativeOpsTests, ReduceTest_1) {
printf("Unsupported for cuda now.\n");
#else
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceFloat(nullptr,
reduce::Mean,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr);
&expBuf, exp.shapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Mean");
ASSERT_TRUE(exp.e<float>(0) == 13.);
@ -289,14 +302,14 @@ TEST_F(NativeOpsTests, ReduceTest_2) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceSame(nullptr,
reduce::Sum,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr);
&expBuf, exp.shapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Sum");
ASSERT_TRUE(exp.e<float>(0) == 325.);
@ -312,14 +325,14 @@ TEST_F(NativeOpsTests, ReduceTest_3) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceBool(nullptr,
reduce::All,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr);
&expBuf, exp.shapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.e<bool>(0) == true);
@ -335,14 +348,14 @@ TEST_F(NativeOpsTests, ReduceTest_4) {
#ifdef __CUDABLAS__
printf("Unsupported for cuda now.\n");
#else
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceLong(nullptr,
reduce::CountNonZero,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr);
&expBuf, exp.shapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
@ -359,15 +372,16 @@ TEST_F(NativeOpsTests, ReduceTest_5) {
printf("Unsupported for cuda now.\n");
#else
auto dimension = NDArrayFactory::create<int>({0, 1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduceLong2(nullptr,
reduce::CountNonZero,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
@ -389,15 +403,17 @@ TEST_F(NativeOpsTests, ReduceTest_6) {
x.p(10, 0); x.p(11, 0);
x.p(15, 0); x.p(16, 0); x.p(17, 0);
x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceLong2(nullptr,
reduce::CountNonZero,
x.buffer(), x.shapeInfo(),
nullptr, nullptr,
&xBuf, x.shapeInfo(), nullptr,
nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
&expBuf, exp.shapeInfo(), nullptr,
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce CountNonZero");
ASSERT_TRUE(exp.equalsTo(z));
@ -421,15 +437,16 @@ TEST_F(NativeOpsTests, ReduceTest_7) {
x.linspace(1.0);
x.syncToDevice();
dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceFloat2(extra,
reduce::Mean,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Mean");
ASSERT_TRUE(exp.equalsTo(z));
@ -453,16 +470,16 @@ TEST_F(NativeOpsTests, ReduceTest_8) {
x.syncToDevice();
dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execReduceSame2(extra,
reduce::Sum,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
z.buffer(), z.shapeInfo(),
z.specialBuffer(), z.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce Sum");
ASSERT_TRUE(exp.equalsTo(z));
@ -485,15 +502,17 @@ TEST_F(NativeOpsTests, ReduceTest_9) {
x.syncToDevice();
dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduceBool2(extra,
reduce::All,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo());
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.equalsTo(z));
@ -518,15 +537,16 @@ TEST_F(NativeOpsTests, Reduce3Test_1) {
y.assign(2.);
x.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduce3(extra,
reduce3::Dot,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo());
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo());
//z.printIndexedBuffer("Z");
//exp.printIndexedBuffer("Reduce3 Dot");
ASSERT_TRUE(exp.equalsTo(z));
@ -551,15 +571,16 @@ TEST_F(NativeOpsTests, Reduce3Test_2) {
y.assign(2.);
x.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execReduce3Scalar(extra,
reduce3::Dot,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo());
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce3 Dot");
ASSERT_TRUE(exp.equalsTo(z));
@ -585,17 +606,18 @@ TEST_F(NativeOpsTests, Reduce3Test_3) {
x.syncToDevice();
dimension.syncToHost();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduce3Tad(extra,
reduce3::Dot,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
nullptr, nullptr, nullptr, nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
@ -630,17 +652,18 @@ TEST_F(NativeOpsTests, Reduce3Test_4) {
auto hTADShapeInfoY = tadPackY.primaryShapeInfo();
auto hTADOffsetsY = tadPackY.primaryOffsets();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execReduce3All(extra,
reduce3::Dot,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
@ -667,14 +690,16 @@ TEST_F(NativeOpsTests, ScalarTest_1) {
//y.assign(2.);
x.syncToDevice();
z.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execScalar(extra,
scalar::Multiply,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), nullptr);
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.equalsTo(z));
@ -700,14 +725,16 @@ TEST_F(NativeOpsTests, ScalarTest_2) {
//y.assign(2.);
x.syncToDevice();
z.syncToDevice();
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execScalarBool(extra,
scalar::GreaterThan,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(), nullptr);
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15) != z.e<bool>(15));
@ -726,13 +753,14 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
printf("Unsupported for CUDA platform yet.\n");
return;
#endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execSummaryStatsScalar(extra,
variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), false);
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Standard Variance");
ASSERT_TRUE(exp.equalsTo(z));
@ -751,13 +779,13 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
printf("Unsupported for CUDA platform yet.\n");
return;
#endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execSummaryStats(extra,
variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(), false);
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Standard Variance");
ASSERT_TRUE(exp.equalsTo(z));
@ -777,15 +805,16 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
return;
#endif
auto dimensions = NDArrayFactory::create<int>({0, 1});
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimensions.dataBuffer());
::execSummaryStatsTad(extra,
variance::SummaryStatsVariance,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
nullptr,
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
dimensions.buffer(), dimensions.shapeInfo(),
dimensions.specialBuffer(), dimensions.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(),
false,
nullptr, nullptr);
// x.printIndexedBuffer("Input");
@ -807,13 +836,15 @@ TEST_F(NativeOpsTests, TransformTest_1) {
return;
#endif
z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformFloat(extra,
transform::Sqrt,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Sqrt is");
@ -834,13 +865,15 @@ TEST_F(NativeOpsTests, TransformTest_2) {
return;
#endif
z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformSame(extra,
transform::Square,
z.buffer(), z.shapeInfo(),
z.specialBuffer(), z.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Square is");
@ -864,13 +897,14 @@ TEST_F(NativeOpsTests, TransformTest_3) {
z.assign(true);
x.p(24, -25);
z.p(24, false);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformBool(extra,
transform::IsPositive,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("IsPositive");
@ -894,13 +928,13 @@ TEST_F(NativeOpsTests, TransformTest_4) {
return;
#endif
//z.linspace(1.);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
::execTransformStrict(extra,
transform::Cosine,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
nullptr);
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Cosine");
@ -932,17 +966,18 @@ TEST_F(NativeOpsTests, ScalarTadTest_1) {
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execScalarTad(extra,
scalar::Multiply,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
nullptr,
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("Reduce All");
@ -977,17 +1012,21 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) {
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
z.assign(true);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer expBuf(exp.dataBuffer());
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
::execScalarBoolTad(extra,
scalar::And,
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo(),
exp.buffer(), exp.shapeInfo(),
exp.specialBuffer(), exp.specialShapeInfo(),
y.buffer(), y.shapeInfo(),
y.specialBuffer(), y.specialShapeInfo(),
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
&expBuf, exp.shapeInfo(),
exp.specialShapeInfo(),
&yBuf, y.shapeInfo(),
y.specialShapeInfo(),
nullptr,
dimension.buffer(), dimension.shapeInfo(),
dimension.specialBuffer(), dimension.specialShapeInfo(),
&dimBuf, dimension.shapeInfo(),
dimension.specialShapeInfo(),
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
// x.printIndexedBuffer("Input");
// exp.printIndexedBuffer("And");
@ -1095,9 +1134,11 @@ TEST_F(NativeOpsTests, PullRowsTest_1) {
#ifdef __CUDABLAS__
nativeStart[1] = (x.getContext()->getCudaStream());
#endif
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
@ -1250,7 +1291,9 @@ TEST_F(NativeOpsTests, RandomTest_1) {
#endif
graph::RandomGenerator rng(1023, 119);
double p = 0.5;
::execRandom(extra, random::BernoulliDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
}
TEST_F(NativeOpsTests, RandomTest_2) {
@ -1264,7 +1307,10 @@ TEST_F(NativeOpsTests, RandomTest_2) {
x.linspace(0, 0.01);
graph::RandomGenerator rng(1023, 119);
double p = 0.5;
::execRandom2(extra, random::DropOut, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
}
TEST_F(NativeOpsTests, RandomTest_3) {
@ -1280,7 +1326,12 @@ TEST_F(NativeOpsTests, RandomTest_3) {
x.linspace(1, -0.01);
graph::RandomGenerator rng(1023, 119);
double p = 0.5;
::execRandom3(extra, random::ProbablisticMerge, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
OpaqueDataBuffer xBuf(x.dataBuffer());
OpaqueDataBuffer yBuf(y.dataBuffer());
OpaqueDataBuffer zBuf(z.dataBuffer());
::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf,
y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
}
TEST_F(NativeOpsTests, RandomTest_4) {
@ -1316,6 +1367,10 @@ TEST_F(NativeOpsTests, SortTests_2) {
#ifdef __CUDABLAS__
extras[1] = LaunchContext::defaultContext()->getCudaStream();
#endif
// OpaqueDataBuffer xBuf(x.dataBuffer());
// OpaqueDataBuffer yBuf(y.dataBuffer());
// OpaqueDataBuffer expBuf(exp.dataBuffer());
// OpaqueDataBuffer dimBuf(exp.dataBuffer());
::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice();
@ -1541,6 +1596,13 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
::deleteShapeList((Nd4jPointer) shapeList);
}
TEST_F(NativeOpsTests, interop_databuffer_tests_1) {
auto idb = ::allocateDataBuffer(100, 10, false);
auto ptr = ::dbPrimaryBuffer(idb);
::deleteDataBuffer(idb);
}
//Uncomment when needed only - massive calculations
//TEST_F(NativeOpsTests, BenchmarkTests_1) {
//

View File

@ -91,3 +91,25 @@ TEST_F(StringTests, Basic_dup_1) {
delete dup;
}
TEST_F(StringTests, byte_length_test_1) {
std::string f("alpha");
auto array = NDArrayFactory::string(f);
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
}
TEST_F(StringTests, byte_length_test_2) {
auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"});
ASSERT_EQ(9, StringUtils::byteLength(array));
}
TEST_F(StringTests, test_split_1) {
auto split = StringUtils::split("alpha beta gamma", " ");
ASSERT_EQ(3, split.size());
ASSERT_EQ(std::string("alpha"), split[0]);
ASSERT_EQ(std::string("beta"), split[1]);
ASSERT_EQ(std::string("gamma"), split[2]);
}

View File

@ -1,5 +1,6 @@
package org.nd4j.autodiff.listeners.debugging;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
@ -113,16 +114,16 @@ public class ExecDebuggingListener extends BaseListener {
if(co.tArgs() != null && co.tArgs().length > 0) {
sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
}
INDArray[] inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments();
val inputs = co.inputArguments();
val outputs = co.outputArguments();
if(inputs != null ) {
for (int i = 0; i < inputs.length; i++) {
sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString());
for (int i = 0; i < inputs.size(); i++) {
sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString());
}
}
if(outputs != null ) {
for (int i = 0; i < outputs.length; i++) {
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString());
for (int i = 0; i < outputs.size(); i++) {
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString());
}
}
} else {
@ -156,22 +157,22 @@ public class ExecDebuggingListener extends BaseListener {
if(co.tArgs() != null && co.tArgs().length > 0 ){
sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
}
INDArray[] inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments();
val inputs = co.inputArguments();
val outputs = co.outputArguments();
if(inputs != null ) {
sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n");
for (int i = 0; i < inputs.length; i++) {
sb.append("INDArray[] inputs = new INDArray[").append(inputs.size()).append("];\n");
for (int i = 0; i < inputs.size(); i++) {
sb.append("inputs[").append(i).append("] = ");
sb.append(createString(inputs[i]))
sb.append(createString(inputs.get(i)))
.append(";\n");
}
sb.append("op.addInputArgument(inputs);\n");
}
if(outputs != null ) {
sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n");
for (int i = 0; i < outputs.length; i++) {
sb.append("INDArray[] outputs = new INDArray[").append(outputs.size()).append("];\n");
for (int i = 0; i < outputs.size(); i++) {
sb.append("outputs[").append(i).append("] = ");
sb.append(createString(outputs[i]))
sb.append(createString(outputs.get(i)))
.append(";\n");
}
sb.append("op.addOutputArgument(outputs);\n");

View File

@ -478,11 +478,11 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
}
throw new IllegalStateException(s);
}
return ((Assert) op).outputArguments();
return ((Assert) op).outputArguments().toArray(new INDArray[0]);
} else if (op instanceof CustomOp) {
CustomOp c = (CustomOp) op;
Nd4j.exec(c);
return c.outputArguments();
return c.outputArguments().toArray(new INDArray[0]);
} else if (op instanceof Op) {
Op o = (Op) op;
Nd4j.exec(o);

View File

@ -457,7 +457,7 @@ public class OpValidation {
for (int i = 0; i < testCase.testFns().size(); i++) {
String error;
try {
error = testCase.testFns().get(i).apply(testCase.op().outputArguments()[i]);
error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i));
} catch (Throwable t) {
throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
}

View File

@ -1,6 +1,7 @@
package org.nd4j.autodiff.validation.listeners;
import lombok.Getter;
import lombok.val;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
@ -50,12 +51,12 @@ public class NonInplaceValidationListener extends BaseListener {
opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
}
} else if(op.getOp() instanceof DynamicCustomOp){
INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments();
opInputs = new INDArray[arr.length];
opInputsOrig = new INDArray[arr.length];
for( int i=0; i<arr.length; i++ ){
opInputsOrig[i] = arr[i];
opInputs[i] = arr[i].dup();
val arr = ((DynamicCustomOp) op.getOp()).inputArguments();
opInputs = new INDArray[arr.size()];
opInputsOrig = new INDArray[arr.size()];
for( int i=0; i<arr.size(); i++ ){
opInputsOrig[i] = arr.get(i);
opInputs[i] = arr.get(i).dup();
}
} else {
throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());

View File

@ -589,6 +589,10 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.Range.class,
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
org.nd4j.linalg.api.ops.util.PrintAffinity.class,
org.nd4j.linalg.api.ops.util.PrintVariable.class,
org.nd4j.linalg.api.ops.compat.CompatSparseToDense.class,
org.nd4j.linalg.api.ops.compat.CompatStringSplit.class,
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,

View File

@ -73,7 +73,7 @@ public class ActivationPReLU extends BaseActivationFunction {
preluBp.addIntegerArguments(axis);
}
}
Nd4j.getExecutioner().execAndReturn(preluBp.build());
Nd4j.exec(preluBp.build());
in.assign(outTemp);
return new Pair<>(in, dLdalpha);
}

View File

@ -23,7 +23,6 @@ import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import net.ericaro.neoitertools.Generator;
import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -998,14 +997,14 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
}
Pair<DataBuffer, DataBuffer> tadInfo =
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
DataBuffer shapeInfo = tadInfo.getFirst();
val shape = Shape.shape(shapeInfo);
val stride = Shape.stride(shapeInfo).asLong();
val jShapeInfo = shapeInfo.asLong();
val shape = Shape.shape(jShapeInfo);
val stride = Shape.stride(jShapeInfo);
long offset = offset() + tadInfo.getSecond().getLong(index);
val ews = shapeInfo.getLong(shapeInfo.getLong(0) * 2 + 2);
char tadOrder = (char) shapeInfo.getInt(shapeInfo.getLong(0) * 2 + 3);
val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2);
char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3);
val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder);
return toTad;
}
@ -2217,9 +2216,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if(isEmpty() || isS())
return false;
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataType.INT)
|| data().originalDataBuffer() != null;
val c2 = (length() < data().length() && data.dataType() != DataType.INT);
val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer());
return c2 || c3;
}
@Override
@ -3585,6 +3585,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
case DOUBLE:
case FLOAT:
case HALF:
case BFLOAT16:
return getDouble(i);
case LONG:
case INT:
@ -3592,6 +3593,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
case UBYTE:
case BYTE:
case BOOL:
case UINT64:
case UINT32:
case UINT16:
return getLong(i);
case UTF8:
case COMPRESSED:
@ -4350,29 +4354,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
//epsilon equals
if (isScalar() && n.isScalar()) {
if (data.dataType() == DataType.FLOAT) {
double val = getDouble(0);
double val2 = n.getDouble(0);
if (isZ()) {
val val = getLong(0);
val val2 = n.getLong(0);
return val == val2;
} else if (isR()) {
val val = getDouble(0);
val val2 = n.getDouble(0);
if (Double.isNaN(val) != Double.isNaN(val2))
return false;
return Math.abs(val - val2) < eps;
} else {
double val = getDouble(0);
double val2 = n.getDouble(0);
} else if (isB()) {
val val = getInt(0);
val val2 = n.getInt(0);
if (Double.isNaN(val) != Double.isNaN(val2))
return false;
return Math.abs(val - val2) < eps;
return val == val2;
}
} else if (isVector() && n.isVector()) {
EqualsWithEps op = new EqualsWithEps(this, n, eps);
Nd4j.getExecutioner().exec(op);
double diff = op.z().getDouble(0);
val op = new EqualsWithEps(this, n, eps);
Nd4j.exec(op);
val diff = op.z().getDouble(0);
return diff < 0.5;
}
@ -4750,8 +4755,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
checkArrangeArray(rearrange);
int[] newShape = doPermuteSwap(shapeOf(), rearrange);
int[] newStride = doPermuteSwap(strideOf(), rearrange);
val newShape = doPermuteSwap(shape(), rearrange);
val newStride = doPermuteSwap(stride(), rearrange);
char newOrder = Shape.getOrder(newShape, newStride, 1);
@ -4777,23 +4782,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return this;
checkArrangeArray(rearrange);
val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange);
val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange);
val newShape = doPermuteSwap(shape(), rearrange);
val newStride = doPermuteSwap(stride(), rearrange);
char newOrder = Shape.getOrder(newShape, newStride, 1);
//Set the shape information of this array: shape, stride, order.
//Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order]
/*for( int i=0; i<rank; i++ ){
shapeInfo.put(1+i,newShape[i]);
shapeInfo.put(1+i+rank,newStride[i]);
}
shapeInfo.put(3+2*rank,newOrder);
*/
val ews = shapeInfo.get(2 * rank + 2);
/*
if (ews < 1 && !attemptedToFindElementWiseStride)
throw new RuntimeException("EWS is -1");
*/
val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty());
setShapeInformation(si);
@ -4813,6 +4806,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
@Deprecated
protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) {
val ret = new long[rearrange.length];
for (int i = 0; i < rearrange.length; i++) {
@ -4821,6 +4815,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret;
}
@Deprecated
protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) {
int[] ret = new int[rearrange.length];
for (int i = 0; i < rearrange.length; i++) {
@ -4829,11 +4824,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret;
}
@Deprecated
protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) {
int[] ret = new int[rearrange.length];
for (int i = 0; i < rearrange.length; i++) {
ret[i] = shape.getInt(rearrange[i]);
}
return ret;
}
protected long[] doPermuteSwap(long[] shape, int[] rearrange) {
val ret = new long[rearrange.length];
for (int i = 0; i < rearrange.length; i++) {
ret[i] = shape[rearrange[i]];
}
return ret;
}
@ -5413,29 +5417,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) {
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
val numWords = this.length();
val ub = (Utf8Buffer) buffer;
// writing length first
val t = length();
val ptr = (BytePointer) ub.pointer();
// now write all strings as bytes
for (int i = 0; i < ub.length(); i++) {
dos.writeByte(ptr.get(i));
}
val bytes = bos.toByteArray();
return FlatArray.createBufferVector(builder, bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer);
@Override
public int toFlatArray(FlatBufferBuilder builder) {
@ -5543,13 +5525,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return !any();
}
@Override
public String getString(long index) {
if (!isS())
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
return ((Utf8Buffer) data).getString(index);
}
/**
* Validate that the operation is being applied on a numerical array (not boolean or utf8).

View File

@ -47,12 +47,9 @@ public interface CustomOp {
*/
boolean isInplaceCall();
List<INDArray> outputArguments();
INDArray[] outputArguments();
INDArray[] inputArguments();
List<INDArray> inputArguments();
long[] iArgs();
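Since outputArguments() and inputArguments() now return List&lt;INDArray&gt; instead of INDArray[], call sites need a small adjustment. A hedged migration sketch (helper class and method names hypothetical), mirroring the patterns used later in this change set:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;

public class OutputArgsMigration {
    static INDArray firstOutput(CustomOp op) {
        // previously: op.outputArguments()[0]
        return op.outputArguments().get(0);
    }

    static INDArray[] allOutputs(CustomOp op) {
        // copy the list when an INDArray[] is still required (as DefaultOpExecutioner.exec() does below)
        return op.outputArguments().toArray(new INDArray[0]);
    }
}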

View File

@ -261,19 +261,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
@Override
public INDArray[] outputArguments() {
if (!outputArguments.isEmpty()) {
return outputArguments.toArray(new INDArray[0]);
}
return new INDArray[0];
public List<INDArray> outputArguments() {
return outputArguments;
}
@Override
public INDArray[] inputArguments() {
if (!inputArguments.isEmpty())
return inputArguments.toArray(new INDArray[0]);
return new INDArray[0];
public List<INDArray> inputArguments() {
return inputArguments;
}
@Override
@ -367,10 +361,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
for (int i = 0; i < args.length; i++) {
// it's possible to get into a situation where the number of args > the number of arrays AT THIS MOMENT
if (i >= arrsSoFar.length)
if (i >= arrsSoFar.size())
continue;
if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape()))
if (!Arrays.equals(args[i].getShape(), arrsSoFar.get(i).shape()))
throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arrsSoFar.get(i).shape()));
}
}

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.compat;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* This is a wrapper for the SparseToDense op that implements the corresponding TF operation
*
* @author raver119@gmail.com
*/
public class CompatSparseToDense extends DynamicCustomOp {
public CompatSparseToDense() {
//
}
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values) {
Preconditions.checkArgument(shape.isZ() && indices.isZ(), "Shape & indices arrays must have integer data types");
inputArguments.add(indices);
inputArguments.add(shape);
inputArguments.add(values);
}
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values, INDArray defaultValue) {
this(indices, shape, values);
Preconditions.checkArgument(defaultValue.dataType() == values.dataType(), "Values array must have the same data type as the defaultValue array");
inputArguments.add(defaultValue);
}
@Override
public String opName() {
return "compat_sparse_to_dense";
}
}
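A hedged usage sketch for this wrapper (class name, shapes and values are illustrative, not part of the change); the output array is resolved by the op's native shape function:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
import org.nd4j.linalg.factory.Nd4j;

public class CompatSparseToDenseSketch {
    public static void main(String[] args) {
        // two non-zero entries of a 3x4 dense matrix, addressed by [row, col] indices
        INDArray indices = Nd4j.createFromArray(new int[][]{{0, 1}, {2, 3}});
        INDArray shape = Nd4j.createFromArray(3, 4);
        INDArray values = Nd4j.createFromArray(5.0f, 7.0f);

        INDArray dense = Nd4j.exec(new CompatSparseToDense(indices, shape, values))[0];
        System.out.println(dense);
    }
}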

View File

@ -0,0 +1,51 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.compat;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* This is a wrapper for the StringSplit op that implements the corresponding TF operation
*
* @author raver119@gmail.com
*/
public class CompatStringSplit extends DynamicCustomOp {
public CompatStringSplit() {
//
}
public CompatStringSplit(INDArray strings, INDArray delimiter) {
Preconditions.checkArgument(strings.isS() && delimiter.isS(), "Input arrays must have one of UTF types");
inputArguments.add(strings);
inputArguments.add(delimiter);
}
public CompatStringSplit(INDArray strings, INDArray delimiter, INDArray indices, INDArray values) {
this(strings, delimiter);
outputArguments.add(indices);
outputArguments.add(values);
}
@Override
public String opName() {
return "compat_string_split";
}
}
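A hedged usage sketch (class name and input values are illustrative); the op produces token indices and token values, following the TF string_split convention:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
import org.nd4j.linalg.factory.Nd4j;

public class CompatStringSplitSketch {
    public static void main(String[] args) {
        INDArray strings = Nd4j.create("alpha,beta,gamma");
        INDArray delimiter = Nd4j.create(",");

        // out[0] holds the indices of the produced tokens, out[1] the token values
        INDArray[] out = Nd4j.exec(new CompatStringSplit(strings, delimiter));
        System.out.println(out[0]);
        System.out.println(out[1]);
    }
}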

View File

@ -107,12 +107,12 @@ public class ScatterUpdate implements CustomOp {
}
@Override
public INDArray[] outputArguments() {
public List<INDArray> outputArguments() {
return op.outputArguments();
}
@Override
public INDArray[] inputArguments() {
public List<INDArray> inputArguments() {
return op.inputArguments();
}

View File

@ -23,7 +23,6 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -172,7 +171,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
@Override
public INDArray[] exec(CustomOp op) {
return execAndReturn(op).outputArguments();
return execAndReturn(op).outputArguments().toArray(new INDArray[0]);
}
@Override
@ -822,7 +821,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
}
@Override
public String getString(Utf8Buffer buffer, long index) {
public String getString(DataBuffer buffer, long index) {
throw new UnsupportedOperationException();
}

View File

@ -20,7 +20,6 @@ import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.*;
@ -32,8 +31,6 @@ import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.List;
@ -411,7 +408,7 @@ public interface OpExecutioner {
* @param index
* @return
*/
String getString(Utf8Buffer buffer, long index);
String getString(DataBuffer buffer, long index);
/**
* Temporary hook

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.util;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
/**
* This is a wrapper for the PrintAffinity op that just prints the affinity & locality status of an INDArray
*
* @author raver119@gmail.com
*/
public class PrintAffinity extends DynamicCustomOp {
public PrintAffinity() {
//
}
public PrintAffinity(INDArray array) {
inputArguments.add(array);
}
@Override
public String opName() {
return "print_affinity";
}
}
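A hedged usage sketch (class name and array contents are illustrative); the native side prints the affinity/locality report, so there is no Java-side return value:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.util.PrintAffinity;
import org.nd4j.linalg.factory.Nd4j;

public class PrintAffinitySketch {
    public static void main(String[] args) {
        INDArray array = Nd4j.zeros(DataType.FLOAT, 2, 3);
        // prints host/device affinity and locality status of the array to stdout
        Nd4j.exec(new PrintAffinity(array));
    }
}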

View File

@ -0,0 +1,66 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.util;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
/**
* This is a wrapper for the PrintVariable op that just prints a Variable to stdout
*
* @author raver119@gmail.com
*/
public class PrintVariable extends DynamicCustomOp {
public PrintVariable() {
//
}
public PrintVariable(INDArray array, boolean printSpecial) {
inputArguments.add(array);
bArguments.add(printSpecial);
}
public PrintVariable(INDArray array) {
this(array, false);
}
public PrintVariable(INDArray array, String message, boolean printSpecial) {
this(array, Nd4j.create(message), printSpecial);
}
public PrintVariable(INDArray array, String message) {
this(array, Nd4j.create(message), false);
}
public PrintVariable(INDArray array, INDArray message, boolean printSpecial) {
this(array, printSpecial);
Preconditions.checkArgument(message.isS(), "Message argument should have String data type, but got [" + message.dataType() +"] instead");
inputArguments.add(message);
}
public PrintVariable(INDArray array, INDArray message) {
this(array, message, false);
}
@Override
public String opName() {
return "print_variable";
}
}
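A hedged usage sketch (class name and values are illustrative) using the (INDArray, String, boolean) constructor, which wraps the message into a String INDArray as shown above:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.util.PrintVariable;
import org.nd4j.linalg.factory.Nd4j;

public class PrintVariableSketch {
    public static void main(String[] args) {
        INDArray array = Nd4j.zeros(DataType.FLOAT, 2, 2);
        // prints the array (prefixed by the message) to stdout from the native side
        Nd4j.exec(new PrintVariable(array, "my array", false));
    }
}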

View File

@ -89,6 +89,11 @@ public class CompressedDataBuffer extends BaseDataBuffer {
// no-op
}
@Override
public Pointer addressPointer() {
return pointer;
}
/**
* Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer
* @param s
@ -194,6 +199,15 @@ public class CompressedDataBuffer extends BaseDataBuffer {
*/
@Override
public DataBuffer create(int[] data) {
throw new UnsupportedOperationException("This operation isn't supported for CompressedDataBuffer");
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
public void pointerIndexerByCurrentType(DataType currentType) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
@Override
public DataBuffer reallocate(long length) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
}
}

View File

@ -98,7 +98,7 @@ public class Convolution {
.build();
Nd4j.getExecutioner().execAndReturn(col2Im);
return col2Im.outputArguments()[0];
return col2Im.outputArguments().get(0);
}
public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW,
@ -187,7 +187,7 @@ public class Convolution {
.build()).build();
Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0];
return im2col.outputArguments().get(0);
}
public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode,
@ -208,7 +208,7 @@ public class Convolution {
.build()).build();
Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0];
return im2col.outputArguments().get(0);
}
/**
@ -298,7 +298,7 @@ public class Convolution {
.build()).build();
Nd4j.getExecutioner().execAndReturn(im2col);
return im2col.outputArguments()[0];
return im2col.outputArguments().get(0);
}
/**

View File

@ -40,7 +40,6 @@ import org.nd4j.graph.FlatArray;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.*;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
@ -1044,16 +1043,7 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) {
switch (type) {
case INT:
return DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, buffer, length);
case DOUBLE:
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, buffer, length);
case FLOAT:
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, buffer, length);
default:
throw new IllegalArgumentException("Illegal opType " + type);
}
return DATA_BUFFER_FACTORY_INSTANCE.create(buffer, type, length, offset);
}
/**
@ -1336,38 +1326,9 @@ public class Nd4j {
* @return the created buffer
*/
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
switch (type) {
case INT:
return DATA_BUFFER_FACTORY_INSTANCE.createInt(buffer, length);
case LONG:
return DATA_BUFFER_FACTORY_INSTANCE.createLong(buffer, length);
case DOUBLE:
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(buffer, length);
case FLOAT:
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(buffer, length);
case HALF:
return DATA_BUFFER_FACTORY_INSTANCE.createHalf(buffer, length);
default:
throw new IllegalArgumentException("Illegal opType " + type);
}
return createBuffer(buffer, type, length, 0);
}
/**
* Create a buffer based on the data opType
*
* @param data the data to create the buffer with
* @return the created buffer
*/
public static DataBuffer createBuffer(byte[] data, int length) {
DataBuffer ret;
if (dataType() == DataType.DOUBLE)
ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length);
else if (dataType() == DataType.HALF)
ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length);
else
ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length);
return ret;
}
/**
* Create a buffer equal of length prod(shape)
@ -2206,6 +2167,7 @@ public class Nd4j {
private static String writeStringForArray(INDArray write) {
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
write = write.dup();
String format = "0.000000000000000000E0";
return "{\n" +
@ -3927,16 +3889,6 @@ public class Nd4j {
return create(shape, stride);
}
/**
* Creates an ndarray with the specified shape
*
* @param rows the rows of the ndarray
* @param columns the columns of the ndarray
* @return the instance
*/
public static INDArray create(int rows, int columns) {
return create(rows, columns, order());
}
/**
* Creates an ndarray with the specified shape
@ -4386,13 +4338,6 @@ public class Nd4j {
return createUninitialized(shape, Nd4j.order());
}
/**
* See {@link #createUninitialized(long)}
*/
public static INDArray createUninitialized(int length) {
return createUninitialized((long)length);
}
/**
* This method creates an *uninitialized* ndarray of specified length and default ordering.
*
@ -4428,37 +4373,6 @@ public class Nd4j {
////////////////////// OTHER ///////////////////////////////
/**
* Creates a 2D array with the specified number of rows and columns, initialized with zero.
*
* @param rows number of rows.
* @param columns number of columns.
* @return the created array.
*/
public static INDArray zeros(long rows, long columns) {
return INSTANCE.zeros(rows, columns);
}
/**
* Creates a 1D array with the specified number of columns initialized with zero.
*
* @param columns number of columns.
* @return the created array
*/
public static INDArray zeros(int columns) {
return INSTANCE.zeros(columns);
}
/**
* Creates a 1D array with the specified data type and number of columns, initialized with zero.
*
* @param dataType data type.
* @param columns number of columns.
* @return the created array.
*/
public static INDArray zeros(DataType dataType, int columns) {
return INSTANCE.create(dataType, new long[]{columns}, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
}
/**
* Creates an array with the specified data type and shape, initialized with zero.
@ -4468,7 +4382,10 @@ public class Nd4j {
* @return the created array.
*/
public static INDArray zeros(DataType dataType, @NonNull long... shape) {
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
if(shape.length == 0)
return Nd4j.scalar(dataType, 0);
return INSTANCE.create(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
}
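A small sketch of the new empty-shape branch (class name hypothetical): with no shape the call now returns a rank-0 scalar, otherwise an array created in Nd4j.order():

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ZerosShapeSketch {
    public static void main(String[] args) {
        INDArray scalar = Nd4j.zeros(DataType.DOUBLE);       // empty shape -> scalar 0.0
        INDArray matrix = Nd4j.zeros(DataType.FLOAT, 2, 3);  // 2x3 zeros
        System.out.println(scalar.rank() + " " + Arrays.toString(matrix.shape()));
    }
}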
/**
@ -4588,31 +4505,6 @@ public class Nd4j {
return INSTANCE.valueArrayOf(rows, columns, value);
}
/**
* Creates a 2D array with the specified number of rows and columns, initialized with ones
*
* @param rows the number of rows in the matrix
* @param columns the columns of the ndarray
* @return the created ndarray
*/
public static INDArray ones(int rows, int columns) {
return INSTANCE.ones(rows, columns);
}
/**
* Create a 2D array with the given rows, columns and data type initialised with ones.
*
* @param dataType data type
* @param rows rows of the new array.
* @param columns columns of the new array.
* @return the created array
*/
public static INDArray ones(DataType dataType, int rows, int columns) {
INDArray ret = INSTANCE.createUninitialized(dataType, new long[]{rows, columns}, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
ret.assign(1);
return ret;
}
/**
* Empty like
*
@ -4817,8 +4709,7 @@ public class Nd4j {
for (int idx : indexes) {
if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) {
throw new IllegalStateException(
"Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
throw new IllegalStateException("Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
}
}
@ -5186,7 +5077,7 @@ public class Nd4j {
pp.toString(NDARRAY_FACTORY_CLASS));
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory");
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
@ -5871,7 +5762,7 @@ public class Nd4j {
arr[e] = sb.get(e + pos);
}
val buffer = new Utf8Buffer(arr, prod);
val buffer = DATA_BUFFER_FACTORY_INSTANCE.createUtf8Buffer(arr, prod);
return Nd4j.create(buffer, shapeOf);
} catch (Exception e) {
throw new RuntimeException(e);

View File

@ -30,6 +30,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* This class provides unified management for Deallocatable resources
@ -43,6 +44,8 @@ public class DeallocatorService {
private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();
private AtomicLong counter = new AtomicLong(0);
public DeallocatorService() {
// we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
@ -69,6 +72,10 @@ public class DeallocatorService {
}
}
public long nextValue() {
return counter.incrementAndGet();
}
/**
* This method adds Deallocatable object instance to tracking system
*

View File

@ -17,10 +17,10 @@
package org.nd4j.serde.jackson.shaded;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
@ -77,10 +77,9 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
jg.writeNumber(v);
break;
case UTF8:
Utf8Buffer utf8B = ((Utf8Buffer)arr.data());
long n = utf8B.getNumWords();
val n = arr.length();
for( int j=0; j<n; j++ ) {
String s = utf8B.getString(j);
String s = arr.getString(j);
jg.writeString(s);
}
break;

View File

@ -16,11 +16,8 @@
package org.nd4j.nativeblas;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
/**
@ -53,14 +50,12 @@ public interface NativeOps {
*/
void execIndexReduceScalar(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dX,
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
Pointer extraParams,
Pointer z,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dZ,
@Cast("Nd4jLong *") LongPointer dZShapeInfo);
/**
@ -75,17 +70,16 @@ public interface NativeOps {
*/
void execIndexReduce(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dX,
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dResult,
@Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/**
* @param opNum
@ -100,38 +94,34 @@ public interface NativeOps {
*/
void execBroadcast(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execBroadcastBool(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/**
@ -146,33 +136,27 @@ public interface NativeOps {
*/
void execPairwiseTransform(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
void execPairwiseTransformBool(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer y,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
@ -186,53 +170,45 @@ public interface NativeOps {
*/
void execReduceFloat(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceSame(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceBool(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
void execReduceLong(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
/**
@ -245,60 +221,56 @@ public interface NativeOps {
*/
void execReduceFloat2(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceSame2(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceBool2(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
void execReduceLong2(PointerPointer extraPointers,
int opNum,
Pointer x,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape);
/**
* @param opNum
@ -312,13 +284,16 @@ public interface NativeOps {
*/
void execReduce3(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
/**
* @param opNum
@ -329,13 +304,16 @@ public interface NativeOps {
* @param yShapeInfo
*/
void execReduce3Scalar(PointerPointer extraPointers, int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo);
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo);
/**
* @param opNum
@ -351,29 +329,37 @@ public interface NativeOps {
*/
void execReduce3Tad(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets);
void execReduce3All(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParamsVals,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeInfo,
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer xTadShape,
@Cast("Nd4jLong *") LongPointer xOffsets,
@Cast("Nd4jLong *") LongPointer yTadShape,
@ -391,22 +377,28 @@ public interface NativeOps {
*/
void execScalar(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer scalar,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams);
void execScalarBool(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer scalar,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams);
/**
@ -418,11 +410,13 @@ public interface NativeOps {
*/
void execSummaryStatsScalar(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
boolean biasCorrected);
/**
@ -436,11 +430,13 @@ public interface NativeOps {
*/
void execSummaryStats(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
boolean biasCorrected);
/**
@ -455,13 +451,16 @@ public interface NativeOps {
*/
void execSummaryStatsTad(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer extraParams,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
boolean biasCorrected,
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets);
@ -478,42 +477,52 @@ public interface NativeOps {
*/
void execTransformFloat(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
void execTransformSame(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
void execTransformStrict(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
void execTransformBool(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
void execTransformAny(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer result,
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams);
/**
@ -532,31 +541,43 @@ public interface NativeOps {
*/
void execScalarTad(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer scalars,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
void execScalarBoolTad(PointerPointer extraPointers,
int opNum,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer scalars,
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);
OpaqueDataBuffer hDimension,
@Cast("Nd4jLong *") LongPointer hDimensionShape,
@Cast("Nd4jLong *") LongPointer dDimensionShape,
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
void specialConcat(PointerPointer extraPointers,
@ -675,10 +696,12 @@ public interface NativeOps {
///////////////
void pullRows(PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
long n,
@Cast("Nd4jLong *") LongPointer indexes,
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
@ -777,28 +800,34 @@ public interface NativeOps {
void execRandom(PointerPointer extraPointers,
int opNum,
Pointer state,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments);
void execRandom3(PointerPointer extraPointers,
int opNum,
Pointer state,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeBuffer,
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeBuffer,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeBuffer,
@Cast("Nd4jLong *") LongPointer dxShapeBuffer,
OpaqueDataBuffer y,
@Cast("Nd4jLong *") LongPointer yShapeBuffer,
@Cast("Nd4jLong *") LongPointer dyShapeBuffer,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments);
void execRandom2(PointerPointer extraPointers,
int opNum,
Pointer state,
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
OpaqueDataBuffer x,
@Cast("Nd4jLong *") LongPointer xShapeBuffer,
@Cast("Nd4jLong *") LongPointer dxShapeBuffer,
OpaqueDataBuffer z,
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
Pointer extraArguments);
////////////////////
@ -967,9 +996,11 @@ public interface NativeOps {
void tear(PointerPointer extras,
Pointer tensor, @Cast("Nd4jLong *") LongPointer xShapeInfo,
Pointer dtensor, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
PointerPointer targets, @Cast("Nd4jLong *") LongPointer zShapeInfo,
OpaqueDataBuffer tensor,
@Cast("Nd4jLong *") LongPointer xShapeInfo,
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
PointerPointer targets,
@Cast("Nd4jLong *") LongPointer zShapeInfo,
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
@Cast("Nd4jLong *") LongPointer tadOffsets);
@ -1121,6 +1152,8 @@ public interface NativeOps {
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
@ -1162,4 +1195,27 @@ public interface NativeOps {
boolean isMinimalRequirementsMet();
boolean isOptimalRequirementsMet();
OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth);
OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset);
Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);
void dbExpandBuffer(OpaqueDataBuffer dataBuffer, long elements);
void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer);
void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer);
void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, Pointer primaryBuffer, long numBytes);
void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, Pointer specialBuffer, long numBytes);
void dbSyncToSpecial(OpaqueDataBuffer dataBuffer);
void dbSyncToPrimary(OpaqueDataBuffer dataBuffer);
void dbTickHostRead(OpaqueDataBuffer dataBuffer);
void dbTickHostWrite(OpaqueDataBuffer dataBuffer);
void dbTickDeviceRead(OpaqueDataBuffer dataBuffer);
void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer);
void deleteDataBuffer(OpaqueDataBuffer dataBuffer);
void dbClose(OpaqueDataBuffer dataBuffer);
int dbLocality(OpaqueDataBuffer dataBuffer);
int dbDeviceId(OpaqueDataBuffer dataBuffer);
void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId);
void dbExpand(OpaqueDataBuffer dataBuffer, long newLength);
}

View File

@ -0,0 +1,206 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
/**
* This class is an opaque pointer to an InteropDataBuffer, used for Java/C++ interop related to the INDArray DataBuffer
*
* @author saudet
*/
public class OpaqueDataBuffer extends Pointer {
// TODO: make this configurable
private static final int MAX_TRIES = 3;
public OpaqueDataBuffer(Pointer p) { super(p); }
/**
* This method allocates a new InteropDataBuffer and returns a pointer to it
* @param numElements
* @param dataType
* @param allocateBoth
* @return
*/
public static OpaqueDataBuffer allocateDataBuffer(long numElements, @NonNull DataType dataType, boolean allocateBoth) {
OpaqueDataBuffer buffer = null;
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
// try to allocate data buffer
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if allocation failed it might be caused by a transient OOM, so we'll try GC
System.gc();
} else {
// just return the buffer
return buffer;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("Allocation failed: [" + em + "]");
}
/**
* This method expands the buffer and copies existing content into the new buffer
*
* PLEASE NOTE: if the InteropDataBuffer doesn't own the actual buffers, the original pointers won't be released
* @param numElements
*/
public void expand(long numElements) {
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
// try to expand the buffer
NativeOpsHolder.getInstance().getDeviceNativeOps().dbExpand(this, numElements);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if expansion failed it might be caused by a transient OOM, so we'll try GC
System.gc();
} else {
// just return
return;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
}
/**
* This method creates a view out of this InteropDataBuffer
*
* @param bytesLength
* @param bytesOffset
* @return
*/
public OpaqueDataBuffer createView(long bytesLength, long bytesOffset) {
OpaqueDataBuffer buffer = null;
int ec = 0;
String em = null;
for (int t = 0; t < MAX_TRIES; t++) {
try {
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateView(this, bytesLength, bytesOffset);
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if view creation failed it might be caused by a transient OOM, so we'll try GC
System.gc();
} else {
// just return
return buffer;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// if MAX_TRIES is over, we'll just throw an exception
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
}
/**
* This method returns the pointer to the primary (host-side) linear buffer.
* @return
*/
public Pointer primaryBuffer() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(this);
}
/**
* This method returns the pointer to the special (device-side) buffer, if any.
* @return
*/
public Pointer specialBuffer() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(this);
}
/**
* This method returns deviceId of this DataBuffer
* @return
*/
public int deviceId() {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbDeviceId(this);
}
/**
* This method allows setting an external pointer as the primary buffer.
*
* PLEASE NOTE: if the InteropDataBuffer owns the current memory buffer, it will be released
* @param ptr
* @param numElements
*/
public void setPrimaryBuffer(Pointer ptr, long numElements) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(this, ptr, numElements);
}
/**
* This method allows setting an external pointer as the special buffer.
*
* PLEASE NOTE: if the InteropDataBuffer owns the current memory buffer, it will be released
* @param ptr
* @param numElements
*/
public void setSpecialBuffer(Pointer ptr, long numElements) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(this, ptr, numElements);
}
/**
* This method synchronizes device memory
*/
public void syncToSpecial() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(this);
}
/**
* This method synchronizes host memory
*/
public void syncToPrimary() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this);
}
}
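A hedged usage sketch for the wrapper above (class name and element counts are illustrative); it only calls methods defined in this file and assumes a native backend is available through NativeOpsHolder:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.nativeblas.OpaqueDataBuffer;

public class OpaqueDataBufferSketch {
    public static void main(String[] args) {
        // allocate space for 1024 FLOAT elements on both host and device (allocateBoth = true);
        // on a non-zero error code the helper retries up to MAX_TRIES times, calling System.gc() in between
        OpaqueDataBuffer buffer = OpaqueDataBuffer.allocateDataBuffer(1024, DataType.FLOAT, true);

        buffer.expand(2048);     // grow the buffer, preserving existing content
        buffer.syncToSpecial();  // push host-side content to the device buffer

        System.out.println("device id: " + buffer.deviceId());
    }
}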

View File

@ -253,6 +253,7 @@
<version>${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version}</version>
<classifier>${dependency.platform}</classifier>
</dependency>
<!--
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>libnd4j</artifactId>
@ -261,6 +262,7 @@
<classifier>${javacpp.platform}-cuda-${cuda.version}</classifier>
<scope>provided</scope>
</dependency>
-->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>

View File

@ -19,6 +19,7 @@ package org.nd4j.jita.allocator.impl;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
@ -29,9 +30,11 @@ import org.nd4j.jita.allocator.time.providers.MillisecondsProvider;
import org.nd4j.jita.allocator.time.providers.OperativeProvider;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -54,8 +57,8 @@ import java.util.concurrent.locks.ReentrantLock;
public class AllocationPoint {
private static Logger log = LoggerFactory.getLogger(AllocationPoint.class);
// thread safety is guaranteed by cudaLock
private volatile PointersPair pointerInfo;
@Getter
private OpaqueDataBuffer ptrDataBuffer;
@Getter
@Setter
@ -104,33 +107,27 @@ public class AllocationPoint {
*/
private volatile int deviceId;
public AllocationPoint() {
//
private long bytes;
public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) {
ptrDataBuffer = opaqueDataBuffer;
this.bytes = bytes;
objectId = Nd4j.getDeallocatorService().nextValue();
}
public void acquireLock() {
//lock.lock();
}
public void releaseLock() {
//lock.unlock();
public void setPointers(Pointer primary, Pointer special, long numberOfElements) {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, primary, numberOfElements);
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, special, numberOfElements);
}
public int getDeviceId() {
return deviceId;
return ptrDataBuffer.deviceId();
}
public void setDeviceId(int deviceId) {
this.deviceId = deviceId;
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId);
}
/*
We assume 1D memory chunk allocations.
*/
@Getter
@Setter
private AllocationShape shape;
private AtomicBoolean enqueued = new AtomicBoolean(false);
@Getter
@ -164,7 +161,7 @@ public class AllocationPoint {
}
public long getNumberOfBytes() {
return shape.getNumberOfBytes();
return bytes;
}
/*
@ -220,67 +217,25 @@ public class AllocationPoint {
* This method returns the CUDA device pointer for this allocation.
* It can be either a device pointer, a pinned memory pointer, or null.
*
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @return
*/
public Pointer getDevicePointer() {
if (pointerInfo == null) {
log.info("pointerInfo is null");
return null;
}
return pointerInfo.getDevicePointer();
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(ptrDataBuffer);
}
/**
* This method returns the host pointer for this allocation.
* It can be either a regular host pointer, a pinned memory pointer, or null.
*
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @return
*/
public Pointer getHostPointer() {
if (pointerInfo == null)
return null;
return pointerInfo.getHostPointer();
}
/**
* This method sets CUDA pointer for this allocation.
* It can be either device pointer, or pinned memory pointer, or null.
*
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
* @param pointerInfo CUDA pointers wrapped into DevicePointerInfo
*/
public void setPointers(@NonNull PointersPair pointerInfo) {
this.pointerInfo = pointerInfo;
}
public PointersPair getPointers() {
return this.pointerInfo;
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(ptrDataBuffer);
}
public synchronized void tickDeviceRead() {
// this.deviceTicks.incrementAndGet();
// this.timerShort.triggerEvent();
// this.timerLong.triggerEvent();
//this.deviceAccessTime.set(realTimeProvider.getCurrentTime());
this.accessDeviceRead = (timeProvider.getCurrentTime());
}
/**
* Returns time, in milliseconds, when this point was accessed on host side
*
* @return
*/
public synchronized long getHostReadTime() {
return accessHostRead;
};
public synchronized long getHostWriteTime() {
return accessHostWrite;
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceRead(ptrDataBuffer);
}
/**
@ -302,7 +257,7 @@ public class AllocationPoint {
}
public synchronized void tickHostRead() {
accessHostRead = (timeProvider.getCurrentTime());
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostRead(ptrDataBuffer);
}
/**
@ -310,17 +265,14 @@ public class AllocationPoint {
*
*/
public synchronized void tickDeviceWrite() {
// deviceAccessTime.set(realTimeProvider.getCurrentTime());
tickDeviceRead();
accessDeviceWrite = (timeProvider.getCurrentTime());
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceWrite(ptrDataBuffer);
}
/**
* This method sets time when this point was changed on host
*/
public synchronized void tickHostWrite() {
tickHostRead();
accessHostWrite = (timeProvider.getCurrentTime());
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostWrite(ptrDataBuffer);
}
/**
@ -329,10 +281,8 @@ public class AllocationPoint {
* @return true, if data is actual, false otherwise
*/
public synchronized boolean isActualOnHostSide() {
boolean result = accessHostWrite >= accessDeviceWrite
|| accessHostRead >= accessDeviceWrite;
return result;
val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
return s <= 0;
}
/**
@ -341,9 +291,8 @@ public class AllocationPoint {
* @return
*/
public synchronized boolean isActualOnDeviceSide() {
boolean result = accessDeviceWrite >= accessHostWrite
|| accessDeviceRead >= accessHostWrite;
return result;
val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
return s >= 0;
}
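A tiny sketch of the dbLocality() sign convention implied by the two checks above; the helper below is hypothetical and only restates isActualOnHostSide()/isActualOnDeviceSide() in words.

// Hypothetical helper spelling out the dbLocality() convention used above:
//   locality < 0  -> only the host (primary) copy is up to date
//   locality == 0 -> host and device copies are both up to date
//   locality > 0  -> only the device (special) copy is up to date
static String describeLocality(int locality) {
    if (locality == 0)
        return "host and device in sync";
    return locality < 0 ? "host copy is current" : "device copy is current";
}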
/**
@ -355,6 +304,6 @@ public class AllocationPoint {
@Override
public String toString() {
return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + ", shape=" + shape + '}';
return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + "}";
}
}
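A short construction sketch for the refactored AllocationPoint, assuming an OpaqueDataBuffer has already been obtained elsewhere; the wrapper class, method, and parameters below are hypothetical, and only the constructor and setters shown above are used.

import org.bytedeco.javacpp.Pointer;
import org.nd4j.nativeblas.OpaqueDataBuffer;

// Hypothetical sketch: wrapping an already-created OpaqueDataBuffer in an AllocationPoint.
class AllocationPointSketch {
    static AllocationPoint wrap(OpaqueDataBuffer odb, long sizeInBytes,
                                Pointer hostPtr, Pointer devicePtr, long numElements) {
        AllocationPoint point = new AllocationPoint(odb, sizeInBytes);
        point.setPointers(hostPtr, devicePtr, numElements); // delegates to dbSetPrimaryBuffer / dbSetSpecialBuffer
        point.tickDeviceWrite();                            // mark the device copy as the most recent one
        return point;
    }
}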

View File

@ -19,12 +19,10 @@ package org.nd4j.jita.allocator.impl;
import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.Aggressiveness;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.time.Ring;
@ -37,29 +35,25 @@ import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.handler.impl.CudaZeroHandler;
import org.nd4j.jita.workspace.CudaWorkspace;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import java.lang.ref.ReferenceQueue;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
@ -285,16 +279,10 @@ public class AtomicAllocator implements Allocator {
*/
@Override
public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
if (buffer instanceof Utf8Buffer)
return null;
return memoryHandler.getDevicePointer(buffer, context);
}
public Pointer getPointer(DataBuffer buffer) {
if (buffer instanceof Utf8Buffer)
return null;
return memoryHandler.getDevicePointer(buffer, getDeviceContext());
}
@ -320,7 +308,7 @@ public class AtomicAllocator implements Allocator {
public Pointer getPointer(INDArray array, CudaContext context) {
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
if (array.isEmpty() || array.isS())
return null;
throw new UnsupportedOperationException("Pew-pew");
return memoryHandler.getDevicePointer(array.data(), context);
}
@ -372,20 +360,17 @@ public class AtomicAllocator implements Allocator {
@Override
public void synchronizeHostData(DataBuffer buffer) {
// we don't want non-committed ops left behind
//Nd4j.getExecutioner().push();
Nd4j.getExecutioner().commit();
// we don't synchronize constant buffers, since we assume they are always valid on host side
if (buffer.isConstant() || buffer.dataType() == DataType.UTF8 || AtomicAllocator.getInstance().getAllocationPoint(buffer).getPointers().getHostPointer() == null) {
return;
}
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
// we actually need synchronization only in a device-dependent environment. no-op otherwise
if (memoryHandler.isDeviceDependant()) {
val point = getAllocationPoint(buffer.getTrackingPoint());
if (point == null)
throw new RuntimeException("AllocationPoint is NULL");
memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point);
}
// we actually need synchronization only in a device-dependent environment. no-op otherwise. managed by native code
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
//assert oPtr.address() == cPtr.address();
//assert buffer.address() == oPtr.address();
}
@ -446,6 +431,7 @@ public class AtomicAllocator implements Allocator {
public AllocationPoint pickExternalBuffer(DataBuffer buffer) {
/**
AllocationPoint point = new AllocationPoint();
Long allocId = objectsTracker.getAndIncrement();
point.setObjectId(allocId);
@ -458,6 +444,9 @@ public class AtomicAllocator implements Allocator {
point.tickHostRead();
return point;
*/
throw new UnsupportedOperationException("Pew-pew");
}
/**
@ -469,69 +458,8 @@ public class AtomicAllocator implements Allocator {
* @param location
*/
@Override
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location,
boolean initialize) {
AllocationPoint point = new AllocationPoint();
useTracker.set(System.currentTimeMillis());
// we use these longs as tracking codes for memory tracking
Long allocId = objectsTracker.getAndIncrement();
//point.attachBuffer(buffer);
point.setObjectId(allocId);
point.setShape(requiredMemory);
/*
if (buffer instanceof CudaIntDataBuffer) {
buffer.setConstant(true);
point.setConstant(true);
}
*/
/*int numBuckets = configuration.getNumberOfGcThreads();
int bucketId = RandomUtils.nextInt(0, numBuckets);
GarbageBufferReference reference =
new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);*/
//point.attachReference(reference);
point.setDeviceId(-1);
if (buffer.isAttached()) {
long reqMem = AllocationUtils.getRequiredMemory(requiredMemory);
// workaround for init order
getMemoryHandler().getCudaContext();
point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());
val workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();
val pair = new PointersPair();
val ptrDev = workspace.alloc(reqMem, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize);
if (ptrDev != null) {
pair.setDevicePointer(ptrDev);
point.setAllocationStatus(AllocationStatus.DEVICE);
} else {
// we allocate initial host pointer only
val ptrHost = workspace.alloc(reqMem, MemoryKind.HOST, requiredMemory.getDataType(), initialize);
pair.setHostPointer(ptrHost);
pair.setDevicePointer(ptrHost);
point.setAllocationStatus(AllocationStatus.HOST);
}
point.setAttached(true);
point.setPointers(pair);
} else {
// we stay naive about the PointersPair: at this level we just don't know which pointers are set. MemoryHandler will be used for that
PointersPair pair = memoryHandler.alloc(location, point, requiredMemory, initialize);
point.setPointers(pair);
}
allocationsMap.put(allocId, point);
//point.tickHostRead();
point.tickDeviceWrite();
//point.setAllocationStatus(location);
return point;
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
throw new UnsupportedOperationException("Pew-pew");
}
@ -619,10 +547,11 @@ public class AtomicAllocator implements Allocator {
*/
if (point.getBuffer() == null) {
purgeZeroObject(bucketId, object, point, false);
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
//freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
throw new UnsupportedOperationException("Pew-pew");
elementsDropped.incrementAndGet();
continue;
//elementsDropped.incrementAndGet();
//continue;
} else {
elementsSurvived.incrementAndGet();
}
@ -682,13 +611,14 @@ public class AtomicAllocator implements Allocator {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
// we deallocate device memory
purgeDeviceObject(threadId, deviceId, object, point, false);
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
//freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
// and we deallocate host memory, since object is dereferenced
purgeZeroObject(point.getBucketId(), object, point, false);
//purgeZeroObject(point.getBucketId(), object, point, false);
elementsDropped.incrementAndGet();
continue;
//elementsDropped.incrementAndGet();
//continue;
throw new UnsupportedOperationException("Pew-pew");
} ;
} else {
elementsSurvived.incrementAndGet();
@ -1014,6 +944,31 @@ public class AtomicAllocator implements Allocator {
this.memoryHandler.memcpy(dstBuffer, srcBuffer);
}
@Override
public void tickHostWrite(DataBuffer buffer) {
getAllocationPoint(buffer).tickHostWrite();
}
@Override
public void tickHostWrite(INDArray array) {
getAllocationPoint(array.data()).tickHostWrite();
}
@Override
public void tickDeviceWrite(INDArray array) {
getAllocationPoint(array.data()).tickDeviceWrite();
}
@Override
public AllocationPoint getAllocationPoint(INDArray array) {
return getAllocationPoint(array.data());
}
@Override
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
return ((BaseCudaDataBuffer) buffer).getAllocationPoint();
}
/**
* This method returns the deviceId for the current thread.
* All values >= 0 are considered valid device IDs; all values < 0 are considered stubs.
@ -1031,48 +986,6 @@ public class AtomicAllocator implements Allocator {
return new CudaPointer(getDeviceId());
}
@Override
public void tickHostWrite(DataBuffer buffer) {
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
point.tickHostWrite();
}
@Override
public void tickHostWrite(INDArray array) {
DataBuffer buffer =
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
tickHostWrite(buffer);
}
@Override
public void tickDeviceWrite(INDArray array) {
DataBuffer buffer =
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
point.tickDeviceWrite();
}
@Override
public AllocationPoint getAllocationPoint(INDArray array) {
if (array.isEmpty())
return null;
DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
return getAllocationPoint(buffer);
}
@Override
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
if (buffer instanceof CompressedDataBuffer) {
log.warn("Trying to get AllocationPoint from CompressedDataBuffer");
throw new RuntimeException("AP CDB");
}
return getAllocationPoint(buffer.getTrackingPoint());
}
@Override
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
memoryHandler.registerAction(context, result, operands);

View File

@ -23,46 +23,21 @@ import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
@Slf4j
public class CudaDeallocator implements Deallocator {
private AllocationPoint point;
private OpaqueDataBuffer opaqueDataBuffer;
public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) {
this.point = buffer.getAllocationPoint();
if (this.point == null)
throw new RuntimeException();
opaqueDataBuffer = buffer.getOpaqueDataBuffer();
}
@Override
public void deallocate() {
log.trace("Deallocating CUDA memory");
// skipping any allocation that is coming from workspace
if (point.isAttached() || point.isReleased()) {
// TODO: remove allocation point as well?
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
return;
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
AtomicAllocator.getInstance().allocationsMap().remove(point.getObjectId());
return;
}
//log.info("Purging {} bytes...", AllocationUtils.getRequiredMemory(point.getShape()));
if (point.getAllocationStatus() == AllocationStatus.HOST) {
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
} else if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
AtomicAllocator.getInstance().purgeDeviceObject(0L, point.getDeviceId(), point.getObjectId(), point, false);
// and we deallocate host memory, since object is dereferenced
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
}
NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer);
}
}

View File

@ -17,6 +17,7 @@
package org.nd4j.jita.allocator.pointers.cuda;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException;
@ -37,8 +38,9 @@ public class cudaStream_t extends CudaPointer {
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
int res = nativeOps.streamSynchronize(this);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
val ec = nativeOps.lastErrorCode();
if (ec != 0)
throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec);
return res;
}
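The check above could be factored into a small reusable guard; a sketch, with the helper name being hypothetical and only the NativeOps calls already used in this file assumed.

// Hypothetical helper wrapping the native error check shown above.
static void throwOnNativeError(NativeOps nativeOps) {
    val ec = nativeOps.lastErrorCode();
    if (ec != 0)
        throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec);
}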

View File

@ -129,7 +129,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape());
long requiredMemoryBytes = point.getNumberOfBytes();
val originalBytes = requiredMemoryBytes;
requiredMemoryBytes += 8 - (requiredMemoryBytes % 8);
@ -147,13 +147,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
if (point.getAllocationStatus() == AllocationStatus.HOST
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(),
false);
//AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
throw new UnsupportedOperationException("Pew-pew");
}
val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
throw new ND4JIllegalStateException("memcpyAsync failed");
}
flowController.commitTransfer(context.getSpecialStream());
@ -176,14 +176,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
if (currentOffset >= MAX_CONSTANT_LENGTH) {
if (point.getAllocationStatus() == AllocationStatus.HOST
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(),
false);
//AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
throw new UnsupportedOperationException("Pew-pew");
}
val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
originalBytes, 1, context.getSpecialStream()) == 0) {
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
throw new ND4JIllegalStateException("memcpyAsync failed");
}
flowController.commitTransfer(context.getSpecialStream());
@ -202,8 +201,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), originalBytes, 1,
context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream());
flowController.commitTransfer(context.getSpecialStream());
long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;
@ -212,7 +210,10 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
// logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);
point.setAllocationStatus(AllocationStatus.CONSTANT);
point.getPointers().setDevicePointer(new CudaPointer(cAddr));
//point.setDevicePointer(new CudaPointer(cAddr));
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
point.setConstant(true);
point.tickDeviceWrite();
point.setDeviceId(deviceId);

View File

@ -32,6 +32,7 @@ import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
@ -70,53 +71,12 @@ public class SynchronousFlowController implements FlowController {
*/
@Override
public void synchronizeToHost(AllocationPoint point) {
if (!point.isActualOnHostSide()) {
val context = allocator.getDeviceContext();
if (!point.isConstant())
waitTillFinished(point);
// if this piece of memory is device-dependant, we'll also issue copyback once
if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
val bytes = AllocationUtils.getRequiredMemory(point.getShape());
if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0)
throw new IllegalStateException("synchronizeToHost memcpyAsync failed: " + point.getShape());
commitTransfer(context.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.DEVICE_TO_HOST);
}
// updating host read timer
point.tickHostRead();
}
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer());
}
@Override
public void synchronizeToDevice(@NonNull AllocationPoint point) {
if (point.isConstant())
return;
if (!point.isActualOnDeviceSide()) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
val context = allocator.getDeviceContext();
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(),
AllocationUtils.getRequiredMemory(point.getShape()),
CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
commitTransfer(context.getSpecialStream());
point.tickDeviceRead();
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
}
}
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer());
}
@Override
@ -147,7 +107,6 @@ public class SynchronousFlowController implements FlowController {
val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
pointData.acquireLock();
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) {
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
@ -172,15 +131,12 @@ public class SynchronousFlowController implements FlowController {
val cId = allocator.getDeviceId();
if (result != null && !result.isEmpty() && !result.isS()) {
if (result != null && !result.isEmpty()) {
Nd4j.getCompressor().autoDecompress(result);
prepareDelayedMemory(result);
val pointData = allocator.getAllocationPoint(result);
val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
pointData.acquireLock();
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data()
: result.data().originalDataBuffer();
@ -206,8 +162,7 @@ public class SynchronousFlowController implements FlowController {
val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
pointData.acquireLock();
Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE);
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
@ -240,14 +195,12 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(result.getLastWriteEvent());
result.setLastWriteEvent(eventsProvider.getEvent());
result.getLastWriteEvent().register(context.getOldStream());
result.releaseLock();
for (AllocationPoint operand : operands) {
eventsProvider.storeEvent(operand.getLastReadEvent());
operand.setLastReadEvent(eventsProvider.getEvent());
operand.getLastReadEvent().register(context.getOldStream());
operand.releaseLock();
}
// context.syncOldStream();
}
@ -263,7 +216,6 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
pointOperand.getLastWriteEvent().register(context.getOldStream());
pointOperand.releaseLock();
}
}
@ -276,14 +228,12 @@ public class SynchronousFlowController implements FlowController {
eventsProvider.storeEvent(point.getLastWriteEvent());
point.setLastWriteEvent(eventsProvider.getEvent());
point.getLastWriteEvent().register(context.getOldStream());
point.releaseLock();
for (INDArray operand : operands) {
if (operand == null || operand.isEmpty())
continue;
val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.releaseLock();
eventsProvider.storeEvent(pointOperand.getLastReadEvent());
pointOperand.setLastReadEvent(eventsProvider.getEvent());
pointOperand.getLastReadEvent().register(context.getOldStream());
@ -295,7 +245,6 @@ public class SynchronousFlowController implements FlowController {
val context = allocator.getDeviceContext();
if (result != null) {
result.acquireLock();
result.setCurrentContext(context);
}
@ -303,7 +252,6 @@ public class SynchronousFlowController implements FlowController {
if (operand == null)
continue;
operand.acquireLock();
operand.setCurrentContext(context);
}

View File

@ -16,6 +16,7 @@
package org.nd4j.jita.handler.impl;
import lombok.var;
import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
@ -44,9 +45,6 @@ import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.GridFlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.jita.memory.impl.CudaDirectProvider;
import org.nd4j.jita.memory.impl.CudaFullCachingProvider;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -99,9 +97,6 @@ public class CudaZeroHandler implements MemoryHandler {
private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
@Getter
private final MemoryProvider memoryProvider;
private final FlowController flowController;
private final AllocationStatus INITIAL_LOCATION;
@ -148,20 +143,6 @@ public class CudaZeroHandler implements MemoryHandler {
throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
}
switch (configuration.getAllocationModel()) {
case CACHE_ALL:
this.memoryProvider = new CudaFullCachingProvider();
break;
case CACHE_HOST:
this.memoryProvider = new CudaCachingZeroProvider();
break;
case DIRECT:
this.memoryProvider = new CudaDirectProvider();
break;
default:
throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]");
}
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
for (int i = 0; i < numDevices; i++) {
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
@ -191,7 +172,7 @@ public class CudaZeroHandler implements MemoryHandler {
int numBuckets = configuration.getNumberOfGcThreads();
long bucketId = RandomUtils.nextInt(0, numBuckets);
long reqMemory = AllocationUtils.getRequiredMemory(point.getShape());
long reqMemory = point.getNumberOfBytes();
zeroUseCounter.addAndGet(reqMemory);
@ -221,130 +202,7 @@ public class CudaZeroHandler implements MemoryHandler {
public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape,
boolean initialize) {
long reqMemory = AllocationUtils.getRequiredMemory(shape);
val context = getCudaContext();
switch (targetMode) {
case HOST: {
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
while (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
val before = MemoryTracker.getInstance().getActiveHostAmount();
memoryProvider.purgeCache();
Nd4j.getMemoryManager().invokeGc();
val after = MemoryTracker.getInstance().getActiveHostAmount();
log.debug("[HOST] before: {}; after: {};", before, after);
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
try {
log.warn("No available [HOST] memory, sleeping for a while... Consider increasing -Xmx next time.");
log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: [" + zeroAllocations.get(0) + "]");
memoryProvider.purgeCache();
Nd4j.getMemoryManager().invokeGc();
Thread.sleep(1000);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
PointersPair pair = memoryProvider.malloc(shape, point, targetMode);
if (initialize) {
org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory);
point.tickHostWrite();
}
pickupHostAllocation(point);
return pair;
}
case DEVICE: {
int deviceId = getDeviceId();
PointersPair returnPair = new PointersPair();
PointersPair tmpPair = new PointersPair();
if (point.getPointers() == null)
point.setPointers(tmpPair);
if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId, reqMemory)) {
point.setDeviceId(deviceId);
val pair = memoryProvider.malloc(shape, point, targetMode);
if (pair != null) {
returnPair.setDevicePointer(pair.getDevicePointer());
point.setAllocationStatus(AllocationStatus.DEVICE);
if (point.getPointers() == null)
throw new RuntimeException("PointersPair can't be null");
point.getPointers().setDevicePointer(pair.getDevicePointer());
deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId());
val p = point.getBucketId();
if (p != null) {
val m = zeroAllocations.get(point.getBucketId());
// m can be null, if that's point from workspace - just no bucketId for it
if (m != null)
m.remove(point.getObjectId());
}
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory);
if (!initialize) {
point.tickDeviceWrite();
} else {
nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0, context.getSpecialStream());
context.getSpecialStream().synchronize();
point.tickDeviceWrite();
}
} else {
log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]; Approximate free bytes: {}; Real free bytes: {}", deviceId, reqMemory, MemoryTracker.getInstance().getApproximateFreeMemory(deviceId), MemoryTracker.getInstance().getPreciseFreeMemory(deviceId));
log.info("Total allocated dev_0: {}", MemoryTracker.getInstance().getActiveMemory(0));
log.info("Cached dev_0: {}", MemoryTracker.getInstance().getCachedAmount(0));
log.info("Allocated dev_0: {}", MemoryTracker.getInstance().getAllocatedAmount(0));
log.info("Workspace dev_0: {}", MemoryTracker.getInstance().getWorkspaceAllocatedAmount(0));
//log.info("Total allocated dev_1: {}", MemoryTracker.getInstance().getActiveMemory(1));
// if device memory allocation failed (aka returned NULL), keep using host memory instead
returnPair.setDevicePointer(tmpPair.getHostPointer());
point.setAllocationStatus(AllocationStatus.HOST);
Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(100);
} catch (Exception e) {
}
}
} else {
log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]",
deviceId);
Nd4j.getMemoryManager().invokeGc();
try {
Thread.sleep(100);
} catch (InterruptedException e) {
//
}
}
return returnPair;
}
default:
throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]");
}
throw new UnsupportedOperationException();
}
/**
@ -356,7 +214,7 @@ public class CudaZeroHandler implements MemoryHandler {
*/
@Override
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
return memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory);
return true;
}
/**
@ -371,47 +229,7 @@ public class CudaZeroHandler implements MemoryHandler {
@Override
public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point,
AllocationShape shape, CudaContext context) {
//log.info("RELOCATE CALLED: [" +currentStatus+ "] -> ["+targetStatus+"]");
if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) {
// DEVICE -> HOST
DataBuffer targetBuffer = point.getBuffer();
if (targetBuffer == null)
throw new IllegalStateException("Target buffer is NULL!");
Pointer devicePointer = new CudaPointer(point.getPointers().getDevicePointer().address());
} else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) {
// HOST -> DEVICE
// TODO: this probably should be removed
if (point.isConstant()) {
//log.info("Skipping relocation for constant");
return;
}
if (point.getPointers().getDevicePointer() == null) {
throw new IllegalStateException("devicePointer is NULL!");
}
val profD = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice,
context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address()
+ "] -> [" + point.getDevicePointer().address() + "]");
flowController.commitTransfer(context.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
//context.syncOldStream();
} else
throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus
+ "] -> [" + targetStatus + "]");
}
/**
@ -440,11 +258,6 @@ public class CudaZeroHandler implements MemoryHandler {
@Override
@Deprecated
public void copyforward(AllocationPoint point, AllocationShape shape) {
/*
Technically that's just a case for relocate, with source as HOST and target point.getAllocationStatus()
*/
log.info("copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape());
//relocate(AllocationStatus.HOST, point.getAllocationStatus(), point, shape);
throw new UnsupportedOperationException("Deprecated call");
}
@ -467,15 +280,7 @@ public class CudaZeroHandler implements MemoryHandler {
*/
@Override
public void free(AllocationPoint point, AllocationStatus target) {
//if (point.getAllocationStatus() == AllocationStatus.DEVICE)
//deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId());
//zeroAllocations.get(point.getBucketId()).remove(point.getObjectId());
if (point.getAllocationStatus() == AllocationStatus.DEVICE)
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(),
AllocationUtils.getRequiredMemory(point.getShape()));
memoryProvider.free(point);
}
/**
@ -525,7 +330,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext tContext = null;
if (dstBuffer.isConstant()) {
org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L);
org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L);
org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length);
val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -534,14 +339,34 @@ public class CudaZeroHandler implements MemoryHandler {
point.tickHostRead();
} else {
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
if (tContext == null)
tContext = flowController.prepareAction(point);
var prof = PerformanceTracker.getInstance().helperStartTransaction();
flowController.commitTransfer(tContext.getSpecialStream());
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
flowController.registerAction(tContext, point);
point.tickDeviceWrite();
// we optionally copy to host memory
if (point.getPointers().getHostPointer() != null) {
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
if (point.getHostPointer() != null) {
Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
CudaContext context = flowController.prepareAction(point);
tContext = context;
val prof = PerformanceTracker.getInstance().helperStartTransaction();
prof = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]");
@ -552,28 +377,10 @@ public class CudaZeroHandler implements MemoryHandler {
if (point.getAllocationStatus() == AllocationStatus.HOST)
flowController.registerAction(context, point);
point.tickHostRead();
}
}
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
if (tContext == null)
tContext = flowController.prepareAction(point);
val prof = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE);
flowController.registerAction(tContext, point);
point.tickDeviceWrite();
}
}
@Override
@ -581,7 +388,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext context) {
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset);
Pointer dP = new CudaPointer((point.getDevicePointer().address()) + dstOffset);
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memcpyAsync failed");
@ -604,7 +411,7 @@ public class CudaZeroHandler implements MemoryHandler {
CudaContext context = getCudaContext();
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
val profH = PerformanceTracker.getInstance().helperStartTransaction();
@ -614,7 +421,7 @@ public class CudaZeroHandler implements MemoryHandler {
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -717,23 +524,22 @@ public class CudaZeroHandler implements MemoryHandler {
@Override
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
// TODO: It would be awesome to get rid of typecasting here
//getCudaContext().syncOldStream();
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
// if that's device state, we probably want to update the device memory state
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
if (!dstPoint.isActualOnDeviceSide()) {
// log.info("Relocating to GPU");
relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
//relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
throw new UnsupportedOperationException("Pew-pew");
}
}
// we update memory use counter, to announce that it's somehow used on device
dstPoint.tickDeviceRead();
if (dstPoint.getDevicePointer() == null)
return null;
// return pointer with offset if needed. length is specified for constructor compatibility purposes
val p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(),
(buffer.offset() * buffer.getElementSize()));
// return pointer. length is specified for constructor compatibility purposes. Offset is accounted for on the C++ side
val p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0);
if (OpProfiler.getInstance().getConfig().isCheckLocality())
NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1);
@ -749,10 +555,17 @@ public class CudaZeroHandler implements MemoryHandler {
case SHORT:
case UINT16:
case HALF:
case BFLOAT16:
return p.asShortPointer();
case UINT64:
case LONG:
return p.asLongPointer();
case UTF8:
case UBYTE:
case BYTE:
return p.asBytePointer();
case BOOL:
return p.asBooleanPointer();
default:
return p;
}
@ -769,17 +582,14 @@ public class CudaZeroHandler implements MemoryHandler {
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
// return pointer with offset if needed. length is specified for constructor compatibility purposes
if (dstPoint.getPointers().getHostPointer() == null) {
if (dstPoint.getHostPointer() == null) {
return null;
}
//dstPoint.tickHostWrite();
//dstPoint.tickHostRead();
//log.info("Requesting host pointer for {}", buffer);
//getCudaContext().syncOldStream();
synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(),
(buffer.offset() * buffer.getElementSize()));
CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0);
switch (buffer.dataType()) {
case DOUBLE:
return p.asDoublePointer();
@ -805,6 +615,9 @@ public class CudaZeroHandler implements MemoryHandler {
public synchronized void relocateObject(DataBuffer buffer) {
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
// we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
return;
@ -838,14 +651,14 @@ public class CudaZeroHandler implements MemoryHandler {
// if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
// host part is optional
if (dstPoint.getHostPointer() != null) {
val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
//val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
//dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
}
val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
//val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
//dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
//log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
////log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
CudaContext context = getCudaContext();
@ -876,10 +689,10 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
//dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
if (dstPoint.getHostPointer() != null) {
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
// dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
}
dstPoint.setDeviceId(deviceId);
@ -908,11 +721,10 @@ public class CudaZeroHandler implements MemoryHandler {
context.syncSpecialStream();
}
memoryProvider.free(dstPoint);
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
//deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
// we replace original device pointer with new one
alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
//alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
val profD = PerformanceTracker.getInstance().helperStartTransaction();
@ -940,6 +752,9 @@ public class CudaZeroHandler implements MemoryHandler {
public boolean promoteObject(DataBuffer buffer) {
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
if (dstPoint.getAllocationStatus() != AllocationStatus.HOST)
return false;
@ -952,20 +767,19 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getConstantHandler().moveToConstantSpace(buffer);
} else {
PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
PointersPair pair = null; //memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
if (pair != null) {
Integer deviceId = getDeviceId();
// log.info("Promoting object to device: [{}]", deviceId);
dstPoint.getPointers().setDevicePointer(pair.getDevicePointer());
//dstPoint.setDevicePointer(pair.getDevicePointer());
dstPoint.setAllocationStatus(AllocationStatus.DEVICE);
deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId());
zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId());
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId,
AllocationUtils.getRequiredMemory(dstPoint.getShape()));
//deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, AllocationUtils.getRequiredMemory(dstPoint.getShape()));
dstPoint.tickHostWrite();
@ -1103,7 +917,7 @@ public class CudaZeroHandler implements MemoryHandler {
if (deviceAllocations.get(deviceId).containsKey(objectId))
throw new IllegalStateException("Can't happen ever");
deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
//deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
point.setAllocationStatus(AllocationStatus.HOST);
@ -1119,6 +933,9 @@ public class CudaZeroHandler implements MemoryHandler {
*/
@Override
public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
if (1 > 0)
throw new UnsupportedOperationException("Pew-pew");
forget(point, AllocationStatus.HOST);
flowController.waitTillReleased(point);
@ -1127,8 +944,8 @@ public class CudaZeroHandler implements MemoryHandler {
if (point.getHostPointer() != null) {
free(point, AllocationStatus.HOST);
long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1;
zeroUseCounter.addAndGet(reqMem);
//long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1;
//zeroUseCounter.addAndGet(reqMem);
}
point.setAllocationStatus(AllocationStatus.DEALLOCATED);
@ -1252,4 +1069,9 @@ public class CudaZeroHandler implements MemoryHandler {
public FlowController getFlowController() {
return flowController;
}
@Override
public MemoryProvider getMemoryProvider() {
return null;
}
}

View File

@ -147,7 +147,7 @@ public class CudaMemoryManager extends BasicMemoryManager {
// Nd4j.getShapeInfoProvider().purgeCache();
// purge memory cache
AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
//AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
}

View File

@ -1,303 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.MemoryProvider;
import org.slf4j.Logger;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.jita.allocator.impl.MemoryTracker;
/**
* This is a MemoryProvider implementation that adds a cache for memory reuse purposes. Only host memory is cached for future reuse.
*
* If a memory chunk gets released via the allocator, it will probably be kept for future reuse within the same JVM process.
*
* @author raver119@gmail.com
*/
public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider {
private static Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);
protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<>();
protected final AtomicLong cacheZeroHit = new AtomicLong(0);
protected final AtomicLong cacheZeroMiss = new AtomicLong(0);
protected final AtomicLong cacheDeviceHit = new AtomicLong(0);
protected final AtomicLong cacheDeviceMiss = new AtomicLong(0);
private final AtomicLong allocRequests = new AtomicLong(0);
protected final AtomicLong zeroCachedAmount = new AtomicLong(0);
protected List<AtomicLong> deviceCachedAmount = new ArrayList<>();
protected final Semaphore singleLock = new Semaphore(1);
// we don't cache allocations greater than this value
//protected final long MAX_SINGLE_ALLOCATION = configuration.getMaximumHostCacheableLength();
// maximum cached size of memory
//protected final long MAX_CACHED_MEMORY = configuration.getMaximumHostCache();
// memory chunks below this threshold are always cached, regardless of the number of cache entries
// that especially covers all possible variations of shapeInfoDataBuffers in all possible cases
protected final long FORCED_CACHE_THRESHOLD = 96;
// number of preallocation entries for each yet-unknown shape
//protected final int PREALLOCATION_LIMIT = configuration.getPreallocationCalls();
public CudaCachingZeroProvider() {
}
/**
* This method provides a PointersPair for the memory chunk specified by AllocationShape
*
* PLEASE NOTE: This method can actually ignore the malloc request and give out a previously cached free memory chunk of equal shape.
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
long reqMemory = AllocationUtils.getRequiredMemory(shape);
if (location == AllocationStatus.HOST && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) {
val cache = zeroCache.get(shape);
if (cache != null) {
val pointer = cache.poll();
if (pointer != null) {
cacheZeroHit.incrementAndGet();
// since this memory chunk is going to be used now, remove its size from the cached total
zeroCachedAmount.addAndGet(-1 * reqMemory);
val pair = new PointersPair();
pair.setDevicePointer(new CudaPointer(pointer.address()));
pair.setHostPointer(new CudaPointer(pointer.address()));
point.setAllocationStatus(AllocationStatus.HOST);
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMemory);
MemoryTracker.getInstance().decrementCachedHostAmount(reqMemory);
return pair;
}
}
cacheZeroMiss.incrementAndGet();
if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation() && zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10
&& reqMemory < 16 * 1024 * 1024L) {
val preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls());
preallocator.start();
}
cacheZeroMiss.incrementAndGet();
return super.malloc(shape, point, location);
}
return super.malloc(shape, point, location);
}
protected void ensureCacheHolder(AllocationShape shape) {
if (!zeroCache.containsKey(shape)) {
try {
singleLock.acquire();
if (!zeroCache.containsKey(shape)) {
zeroCache.put(shape, new CacheHolder(shape, zeroCachedAmount));
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
singleLock.release();
}
}
}
/**
* This method frees the specific chunk of memory described by the AllocationPoint passed in.
*
* PLEASE NOTE: this method may ignore the actual free and keep the released memory chunk for future reuse.
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
super.free(point);
} else {
// if this point has no allocated chunk - step over it
if (point.getHostPointer() == null)
return;
AllocationShape shape = point.getShape();
long reqMemory = AllocationUtils.getRequiredMemory(shape);
// we don't cache objects that are too big
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength() || zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) {
super.free(point);
return;
}
ensureCacheHolder(shape);
/*
Now we should decide if this object can be cached or not
*/
CacheHolder cache = zeroCache.get(shape);
// memory chunks < threshold will be cached no matter what
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
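// wipe the chunk before parking it in the cache, so a later cache hit hands out zeroed memory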
Pointer.memset(point.getHostPointer(), 0, reqMemory);
cache.put(new CudaPointer(point.getHostPointer().address()));
} else {
long cacheEntries = cache.size();
long cacheHeight = zeroCache.size();
// total memory allocated within this bucket
long cacheDepth = cacheEntries * reqMemory;
Pointer.memset(point.getHostPointer(), 0, reqMemory);
cache.put(new CudaPointer(point.getHostPointer().address()));
}
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMemory);
MemoryTracker.getInstance().incrementCachedHostAmount(reqMemory);
}
}
private float getZeroCacheHitRatio() {
long totalHits = cacheZeroHit.get() + cacheZeroMiss.get();
float cacheRatio = cacheZeroHit.get() * 100 / (float) totalHits;
return cacheRatio;
}
private float getDeviceCacheHitRatio() {
long totalHits = cacheDeviceHit.get() + cacheDeviceMiss.get();
float cacheRatio = cacheDeviceHit.get() * 100 / (float) totalHits;
return cacheRatio;
}
@Deprecated
public void printCacheStats() {
log.debug("Cached host amount: " + zeroCachedAmount.get());
log.debug("Cached device amount: " + deviceCachedAmount.get(0).get());
log.debug("Total shapes in cache: " + zeroCache.size());
log.debug("Current host hit ratio: " + getZeroCacheHitRatio());
log.debug("Current device hit ratio: " + getDeviceCacheHitRatio());
}
protected class CacheHolder {
private Queue<Pointer> queue = new ConcurrentLinkedQueue<>();
private volatile int counter = 0;
private long reqMem = 0;
private final AtomicLong allocCounter;
public CacheHolder(AllocationShape shape, AtomicLong counter) {
this.reqMem = AllocationUtils.getRequiredMemory(shape);
this.allocCounter = counter;
}
public synchronized int size() {
return counter;
}
public synchronized Pointer poll() {
val pointer = queue.poll();
if (pointer != null)
counter--;
return pointer;
}
public synchronized void put(Pointer pointer) {
allocCounter.addAndGet(reqMem);
counter++;
queue.add(pointer);
}
}
protected class CachePreallocator extends Thread implements Runnable {
private AllocationShape shape;
private AllocationStatus location;
private int target;
public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) {
this.shape = shape;
this.target = numberOfEntries;
this.location = location;
}
@Override
public void run() {
ensureCacheHolder(shape);
for (int i = 0; i < target; i++) {
val point = new AllocationPoint();
val pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location);
if (this.location == AllocationStatus.HOST) {
Pointer pointer = new CudaPointer(pair.getHostPointer().address());
CudaCachingZeroProvider.this.zeroCache.get(shape).put(pointer);
}
}
}
}
@Override
public void purgeCache() {
for (AllocationShape shape : zeroCache.keySet()) {
Pointer ptr = null;
while ((ptr = zeroCache.get(shape).poll()) != null) {
freeHost(ptr);
MemoryTracker.getInstance().decrementCachedHostAmount(shape.getNumberOfBytes());
}
}
zeroCachedAmount.set(0);
}
}
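For orientation, here is a minimal hypothetical sketch of the round trip this provider implements. It is not part of the diff; it assumes the usual org.nd4j imports (MemoryProvider, AllocationShape, AllocationPoint, AllocationStatus, PointersPair, DataType) and that the surrounding allocator owns the AllocationPoint bookkeeping, as AtomicAllocator normally does.

// hypothetical usage sketch, not part of this change
MemoryProvider provider = new CudaCachingZeroProvider();
AllocationShape shape = new AllocationShape(256, 8, DataType.DOUBLE); // 2 KB host chunk
AllocationPoint point = new AllocationPoint();
// first request: zeroCache has no entry for this shape yet, so the call falls
// through to CudaDirectProvider.malloc and cacheZeroMiss is incremented
PointersPair pair = provider.malloc(shape, point, AllocationStatus.HOST);
// when the allocator later calls provider.free(point), the chunk is memset to 0
// and parked in zeroCache instead of being released to the OS (size limits permitting)
// a second request with an equal shape is then served from the cache, incrementing
// cacheZeroHit and skipping the native mallocHost call entirely
PointersPair reused = provider.malloc(shape, new AllocationPoint(), AllocationStatus.HOST);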

View File

@ -1,239 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import lombok.var;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* @author raver119@gmail.com
*/
public class CudaDirectProvider implements MemoryProvider {
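// roughly 50 MB of device memory is kept in reserve; pingDeviceForFreeMemory() below rejects allocations that would dip into it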
protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L;
private static Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);
protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
protected volatile ConcurrentHashMap<Long, Integer> validator = new ConcurrentHashMap<>();
private AtomicLong emergencyCounter = new AtomicLong(0);
/**
* This method provides a PointersPair for the memory chunk specified by the given AllocationShape.
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
//log.info("shape onCreate: {}, target: {}", shape, location);
switch (location) {
case HOST: {
long reqMem = AllocationUtils.getRequiredMemory(shape);
// FIXME: this is WRONG, and directly leads to memleak
if (reqMem < 1)
reqMem = 1;
val pointer = nativeOps.mallocHost(reqMem, 0);
if (pointer == null)
throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: "
+ Thread.currentThread().getId());
// log.info("Host allocation, Thread id: {}, ReqMem: {}, Pointer: {}", Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null);
val hostPointer = new CudaPointer(pointer);
val devicePointerInfo = new PointersPair();
if (point.getPointers().getDevicePointer() == null) {
point.setAllocationStatus(AllocationStatus.HOST);
devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem));
} else
devicePointerInfo.setDevicePointer(point.getDevicePointer());
devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem));
point.setPointers(devicePointerInfo);
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem);
return devicePointerInfo;
}
case DEVICE: {
// cudaMalloc call
val deviceId = AtomicAllocator.getInstance().getDeviceId();
long reqMem = AllocationUtils.getRequiredMemory(shape);
// FIXME: this is WRONG, and directly leads to memleak
if (reqMem < 1)
reqMem = 1;
AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem);
var pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
if (pointer == null) {
// try to purge stuff if we're low on memory
purgeCache(deviceId);
// call for gc
Nd4j.getMemoryManager().invokeGc();
pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
if (pointer == null)
return null;
}
val devicePointer = new CudaPointer(pointer);
var devicePointerInfo = point.getPointers();
if (devicePointerInfo == null)
devicePointerInfo = new PointersPair();
devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem));
point.setAllocationStatus(AllocationStatus.DEVICE);
point.setDeviceId(deviceId);
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem);
return devicePointerInfo;
}
default:
throw new IllegalStateException("Unsupported location for malloc: [" + location + "]");
}
}
/**
* This method frees the specific chunk of memory described by the AllocationPoint passed in.
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
switch (point.getAllocationStatus()) {
case HOST: {
// cudaFreeHost call here
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
long result = nativeOps.freeHost(point.getPointers().getHostPointer());
if (result == 0) {
throw new RuntimeException("Can't deallocate [HOST] memory...");
}
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem);
}
break;
case DEVICE: {
if (point.isConstant())
return;
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem);
val pointers = point.getPointers();
long result = nativeOps.freeDevice(pointers.getDevicePointer(), 0);
if (result == 0)
throw new RuntimeException("Can't deallocate [DEVICE] memory...");
MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem);
}
break;
default:
throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]");
}
}
/**
* This method checks whether the specified device has the specified amount of free memory available.
*
* @param deviceId
* @param requiredMemory
* @return
*/
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
/*
long[] totalMem = new long[1];
long[] freeMem = new long[1];
JCuda.cudaMemGetInfo(freeMem, totalMem);
long free = freeMem[0];
long total = totalMem[0];
long used = total - free;
/*
We don't want to allocate memory if it's too close to the end of available ram.
*/
//if (configuration != null && used > total * configuration.getMaxDeviceMemoryUsed()) return false;
/*
if (free + requiredMemory < total * 0.85)
return true;
else return false;
*/
long freeMem = nativeOps.getDeviceFreeMemory(-1);
if (freeMem - requiredMemory < DEVICE_RESERVED_SPACE)
return false;
else
return true;
}
protected void freeHost(Pointer pointer) {
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
nativeOps.freeHost(pointer);
}
protected void freeDevice(Pointer pointer, int deviceId) {
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
nativeOps.freeDevice(pointer, 0);
}
protected void purgeCache(int deviceId) {
//
}
@Override
public void purgeCache() {
// no-op
}
}
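As a hedged contrast to the caching subclasses (not part of the diff, same import assumptions as above): this provider does no caching at all, so every request maps straight onto the native allocator.

// hypothetical sketch of the non-caching path, not part of this change
MemoryProvider direct = new CudaDirectProvider();
AllocationShape shape = new AllocationShape(1024, 4, DataType.FLOAT); // 4 KB
AllocationPoint point = new AllocationPoint();
// HOST goes through nativeOps.mallocHost, DEVICE through nativeOps.mallocDevice
PointersPair hostPair = direct.malloc(shape, point, AllocationStatus.HOST);
// free(point) would immediately call freeHost/freeDevice and update MemoryTracker,
// once the allocator has filled in the point's shape and pointers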

View File

@ -1,220 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.jita.memory.impl;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* This MemoryProvider implementation does caching for both host and device memory within predefined limits.
*
* @author raver119@gmail.com
*/
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
//protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation();
//protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache();
protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CacheHolder>> deviceCache =
new ConcurrentHashMap<>();
private static Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);
public CudaFullCachingProvider() {
init();
}
public void init() {
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
deviceCachedAmount = new ArrayList<>();
for (int i = 0; i < numDevices; i++) {
deviceCachedAmount.add(new AtomicLong(0));
}
}
/**
* This method provides a PointersPair for the memory chunk specified by the given AllocationShape.
*
* PLEASE NOTE: this method may ignore the actual malloc request and hand out a previously cached free memory chunk with an equal shape.
*
* @param shape shape of desired memory chunk
* @param point target AllocationPoint structure
* @param location either HOST or DEVICE
* @return
*/
@Override
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
val reqMemory = AllocationUtils.getRequiredMemory(shape);
if (location == AllocationStatus.DEVICE && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) {
val deviceId = AtomicAllocator.getInstance().getDeviceId();
ensureDeviceCacheHolder(deviceId, shape);
val cache = deviceCache.get(deviceId).get(shape);
if (cache != null) {
val pointer = cache.poll();
if (pointer != null) {
cacheDeviceHit.incrementAndGet();
deviceCachedAmount.get(deviceId).addAndGet(-reqMemory);
val pair = new PointersPair();
pair.setDevicePointer(pointer);
point.setAllocationStatus(AllocationStatus.DEVICE);
point.setDeviceId(deviceId);
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory);
return pair;
}
}
cacheDeviceMiss.incrementAndGet();
return super.malloc(shape, point, location);
}
return super.malloc(shape, point, location);
}
/**
* This method frees the specific chunk of memory described by the AllocationPoint passed in.
*
* PLEASE NOTE: this method may ignore the actual free and keep the released memory chunk for future reuse.
*
* @param point
*/
@Override
public void free(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
if (point.isConstant())
return;
val shape = point.getShape();
val deviceId = point.getDeviceId();
val address = point.getDevicePointer().address();
val reqMemory = AllocationUtils.getRequiredMemory(shape);
// we don't cache objects that are too big
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength() || deviceCachedAmount.get(deviceId).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) {
super.free(point);
return;
}
ensureDeviceCacheHolder(deviceId, shape);
val cache = deviceCache.get(deviceId).get(shape);
if (point.getDeviceId() != deviceId)
throw new RuntimeException("deviceId changed!");
// memory chunks < threshold will be cached no matter what
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
cache.put(new CudaPointer(point.getDevicePointer().address()));
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
return;
} else {
cache.put(new CudaPointer(point.getDevicePointer().address()));
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
return;
}
}
super.free(point);
}
/**
* This method checks whether the storage contains a CacheHolder for the specified shape.
*
* @param deviceId
* @param shape
*/
protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
if (!deviceCache.containsKey(deviceId)) {
try {
synchronized (this) {
if (!deviceCache.containsKey(deviceId)) {
deviceCache.put(deviceId, new ConcurrentHashMap<AllocationShape, CacheHolder>());
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (!deviceCache.get(deviceId).containsKey(shape)) {
try {
singleLock.acquire();
if (!deviceCache.get(deviceId).containsKey(shape)) {
deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount.get(deviceId)));
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
singleLock.release();
}
}
}
@Override
protected synchronized void purgeCache(int deviceId) {
for (AllocationShape shape : deviceCache.get(deviceId).keySet()) {
Pointer ptr = null;
while ((ptr = deviceCache.get(deviceId).get(shape).poll()) != null) {
freeDevice(ptr, deviceId);
MemoryTracker.getInstance().decrementCachedAmount(deviceId, shape.getNumberOfBytes());
}
}
deviceCachedAmount.get(deviceId).set(0);
}
@Override
public synchronized void purgeCache() {
for (Integer device : deviceCache.keySet()) {
purgeCache(device);
}
super.purgeCache();
}
}
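A hedged sketch of the device-side path this subclass adds (not part of the diff, same import assumptions as above):

// hypothetical sketch, not part of this change: device allocations below
// getMaximumDeviceAllocation() go through the per-device, per-shape cache
MemoryProvider full = new CudaFullCachingProvider();
AllocationShape shape = new AllocationShape(1 << 20, 4, DataType.FLOAT); // 4 MB device chunk
AllocationPoint point = new AllocationPoint();
PointersPair devPair = full.malloc(shape, point, AllocationStatus.DEVICE);
// on free(point), the chunk is parked in deviceCache.get(deviceId).get(shape);
// a later malloc with an equal shape on the same device is served from that queue
// (cacheDeviceHit), otherwise it falls back to the host-caching/direct path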

View File

@ -17,34 +17,39 @@
package org.nd4j.linalg.jcublas;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.FlatArray;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.util.PrintVariable;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
/**
*
@ -387,10 +392,6 @@ public class JCublasNDArray extends BaseNDArray {
super(data, order);
}
public JCublasNDArray(FloatBuffer floatBuffer, char order) {
super(floatBuffer, order);
}
public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) {
super(buffer, shape, strides);
}
@ -574,26 +575,16 @@ public class JCublasNDArray extends BaseNDArray {
MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST;
val prof = PerformanceTracker.getInstance().helperStartTransaction();
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
// d2d copy
if (srcPoint.isActualOnDeviceSide()) {
route = 1;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickDeviceWrite();
direction = MemcpyDirection.DEVICE_TO_DEVICE;
} else if (dstPoint.getAllocationStatus() == AllocationStatus.HOST && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
route = 2;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToHost, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickHostWrite();
direction = MemcpyDirection.DEVICE_TO_HOST;
} else if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.HOST) {
} else {
route = 3;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickDeviceWrite();
direction = MemcpyDirection.HOST_TO_DEVICE;
} else {
route = 4;
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToHost, blocking ? context.getOldStream() : context.getSpecialStream());
dstPoint.tickHostWrite();
}
@ -650,30 +641,16 @@ public class JCublasNDArray extends BaseNDArray {
Nd4j.getMemoryManager().setCurrentWorkspace(target);
// log.info("Leveraging...");
INDArray copy = null;
if (!this.isView()) {
//if (1 < 0) {
Nd4j.getExecutioner().commit();
DataBuffer buffer = Nd4j.createBuffer(this.length(), false);
val buffer = Nd4j.createBuffer(this.length(), false);
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
/*
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointDst.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memsetAsync 1 failed");
context.syncOldStream();
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointSrc.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
throw new ND4JIllegalStateException("memsetAsync 2 failed");
context.syncOldStream();
*/
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
@ -690,12 +667,11 @@ public class JCublasNDArray extends BaseNDArray {
context.syncOldStream();
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction);
copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
// tag buffer as valid on device side
pointDst.tickHostRead();
pointDst.tickDeviceWrite();
AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
@ -728,6 +704,7 @@ public class JCublasNDArray extends BaseNDArray {
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
@ -764,6 +741,38 @@ public class JCublasNDArray extends BaseNDArray {
return copy;
}
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) {
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
val numWords = this.length();
val ub = (CudaUtf8Buffer) buffer;
// writing length first
val t = length();
val ptr = (BytePointer) ub.pointer();
// now write all strings as bytes
for (int i = 0; i < ub.length(); i++) {
dos.writeByte(ptr.get(i));
}
val bytes = bos.toByteArray();
return FlatArray.createBufferVector(builder, bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public String getString(long index) {
if (!isS())
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
return ((CudaUtf8Buffer) data).getString(index);
}
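For context, a hypothetical caller-side sketch of the new string accessor. It is not part of the diff, and the creation overload on the Nd4j facade is an assumption; only the corresponding factory method appears further below in this diff.

// hypothetical usage sketch, not part of this change
INDArray words = Nd4j.create(Arrays.asList("alpha", "beta"), new long[]{2}, 'c'); // overload assumed
String first = words.getString(0); // delegates to CudaUtf8Buffer.getString(index)
// calling getString(...) on a non-string array throws UnsupportedOperationException, as above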
/*
@Override
public INDArray convertToHalfs() {

View File

@ -18,11 +18,9 @@ package org.nd4j.linalg.jcublas;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import lombok.var;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
@ -34,12 +32,10 @@ import org.nd4j.linalg.jcublas.buffer.*;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.primitives.Pair;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -51,19 +47,12 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.blas.*;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.*;
/**
@ -216,7 +205,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override
public INDArray create(Collection<String> strings, long[] shape, char order) {
val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8);
val buffer = new Utf8Buffer(strings);
val buffer = new CudaUtf8Buffer(strings);
val list = new ArrayList<String>(strings);
return Nd4j.createArrayFromShapeBuffer(buffer, pairShape);
}
@ -360,8 +349,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override
public INDArray concat(int dimension, INDArray... toConcat) {
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
Nd4j.getExecutioner().push();
return Nd4j.exec(new Concat(dimension, toConcat))[0];
}
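The rewritten concat above pushes any queued executioner work and then delegates to the Concat custom op through Nd4j.exec. A hedged usage sketch, not part of the diff:

// hypothetical usage sketch, not part of this change
INDArray a = Nd4j.ones(DataType.FLOAT, 2, 3);
INDArray b = Nd4j.zeros(DataType.FLOAT, 2, 3);
INDArray joined = Nd4j.exec(new Concat(0, a, b))[0]; // result shape: [4, 3]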
@ -517,9 +505,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
AtomicAllocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(ret, source);
Pointer x = AtomicAllocator.getInstance().getPointer(source, context);
val x = ((BaseCudaDataBuffer) source.data()).getOpaqueDataBuffer();
val z = ((BaseCudaDataBuffer) ret.data()).getOpaqueDataBuffer();
Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context);
Pointer z = AtomicAllocator.getInstance().getPointer(ret, context);
Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
@ -545,14 +533,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
nativeOps.pullRows(extras,
null,
(LongPointer) source.shapeInfoDataBuffer().addressPointer(),
x,
(LongPointer) xShape,
null,
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
z,
(LongPointer) zShape,
x, (LongPointer) source.shapeInfoDataBuffer().addressPointer(), (LongPointer) xShape,
z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) zShape,
indexes.length,
(LongPointer) pIndex,
(LongPointer) tadShapeInfo,
@ -601,7 +583,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
xPointers[i] = point.getPointers().getDevicePointer().address();
xPointers[i] = point.getDevicePointer().address();
point.tickDeviceWrite();
}
@ -710,7 +692,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
xPointers[i] = point.getPointers().getDevicePointer().address();
xPointers[i] = point.getDevicePointer().address();
point.tickDeviceWrite();
}
@ -1324,11 +1306,11 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
PointerPointer extraz = new PointerPointer(null, // not used
context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());
val x = ((BaseCudaDataBuffer) tensor.data()).getOpaqueDataBuffer();
nativeOps.tear(extraz,
null,
(LongPointer) tensor.shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(tensor, context),
(LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
x, (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)),
(LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context),
(LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),

View File

@ -46,6 +46,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length);
}
public CudaBfloat16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/**
* Base constructor
*
@ -128,18 +132,6 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset);
}
public CudaBfloat16DataBuffer(byte[] data, long length) {
super(data, length, DataType.BFLOAT16);
}
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.BFLOAT16);
}
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.BFLOAT16);
}
@Override
public void assign(long[] indices, double[] data, boolean contiguous, long inc) {

View File

@ -50,6 +50,10 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length);
}
public CudaBoolDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/**
* Base constructor
*
@ -132,18 +136,6 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset);
}
public CudaBoolDataBuffer(byte[] data, long length) {
super(data, length, DataType.HALF);
}
public CudaBoolDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.HALF);
}
public CudaBoolDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.HALF);
}
@Override
protected DataBuffer create(long length) {
return new CudaBoolDataBuffer(length);

View File

@ -49,6 +49,10 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length);
}
public CudaByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/**
* Base constructor
*
@ -131,18 +135,6 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset);
}
public CudaByteDataBuffer(byte[] data, long length) {
super(data, length, DataType.HALF);
}
public CudaByteDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.HALF);
}
public CudaByteDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.HALF);
}
@Override
protected DataBuffer create(long length) {
return new CudaByteDataBuffer(length);

View File

@ -49,6 +49,10 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length);
}
public CudaDoubleDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/**
* Base constructor
*
@ -138,18 +142,6 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset);
}
public CudaDoubleDataBuffer(byte[] data, long length) {
super(data, length, DataType.DOUBLE);
}
public CudaDoubleDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.DOUBLE);
}
public CudaDoubleDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.DOUBLE);
}
@Override
protected DataBuffer create(long length) {
return new CudaDoubleDataBuffer(length);
@ -210,14 +202,7 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
this.length = n;
this.elementSize = 8;
//wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize());
//wrappedBuffer.order(ByteOrder.nativeOrder());
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
new AllocationShape(length, elementSize, DataType.DOUBLE), false);
this.trackingPoint = allocationPoint.getObjectId();
//this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer();
//this.wrappedBuffer.order(ByteOrder.nativeOrder());
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.DOUBLE), false);
setData(arr);
}

View File

@ -50,6 +50,10 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
super(pointer, specialPointer, indexer, length);
}
public CudaFloatDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
super(buffer, dataType, length, offset);
}
/**
* Base constructor
*
@ -133,19 +137,6 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset);
}
public CudaFloatDataBuffer(byte[] data, long length) {
super(data, length, DataType.FLOAT);
}
public CudaFloatDataBuffer(ByteBuffer buffer, long length) {
super(buffer, (int) length, DataType.FLOAT);
}
public CudaFloatDataBuffer(ByteBuffer buffer, long length, long offset) {
super(buffer, length, offset, DataType.FLOAT);
}
@Override
protected DataBuffer create(long length) {
return new CudaFloatDataBuffer(length);

Some files were not shown because too many files have changed in this diff.