String changes (#3)
* initial commit * additional data types & tensor type Signed-off-by: raver119 <raver119@gmail.com> * next step Signed-off-by: raver119 <raver119@gmail.com> * missing include * sparse_to_dense Signed-off-by: raver119 <raver119@gmail.com> * few more tests files Signed-off-by: raver119 <raver119@gmail.com> * draft Signed-off-by: raver119 <raver119@gmail.com> * numeric sparse_to_dense Signed-off-by: raver119 <raver119@gmail.com> * comment Signed-off-by: raver119 <raver119@gmail.com> * string sparse_to_dense version Signed-off-by: raver119 <raver119@gmail.com> * CUDA DataBuffer expand Signed-off-by: raver119 <raver119@gmail.com> * few tweaks for CUDA build Signed-off-by: raver119 <raver119@gmail.com> * shape fn for string_split Signed-off-by: raver119 <raver119@gmail.com> * one more comment Signed-off-by: raver119 <raver119@gmail.com> * string_split indices Signed-off-by: raver119 <raver119@gmail.com> * next step Signed-off-by: raver119 <raver119@gmail.com> * test passes Signed-off-by: raver119 <raver119@gmail.com> * few rearrangements for databuffer implementations Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer: move inline methods to common implementations Signed-off-by: raver119 <raver119@gmail.com> * add native DataBuffer to Nd4j presets Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer creation Signed-off-by: raver119 <raver119@gmail.com> * use DataBuffer for allocation Signed-off-by: raver119 <raver119@gmail.com> * cpu databuffer as deallocatable Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer setters for bufers Signed-off-by: raver119 <raver119@gmail.com> * couple of wrappers Signed-off-by: raver119 <raver119@gmail.com> * DataBuffers being passed around Signed-off-by: raver119 <raver119@gmail.com> * Bunch of ByteBuffer-related signatures gone Signed-off-by: raver119 <raver119@gmail.com> * - few more Nd4j signatures removed - minor fix for bfloat16 Signed-off-by: raver119 <raver119@gmail.com> * nullptr pointer is still a pointer, but 0 as address :) Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * empty string array init Signed-off-by: raver119 <raver119@gmail.com> * one more test in cpp Signed-off-by: raver119 <raver119@gmail.com> * memcpy instead of databuffer swap Signed-off-by: raver119 <raver119@gmail.com> * special InteropDataBuffer for front-end languages Signed-off-by: raver119 <raver119@gmail.com> * few tweaks for java Signed-off-by: raver119 <raver119@gmail.com> * pointer/indexer actualization Signed-off-by: raver119 <raver119@gmail.com> * CustomOp returns list for inputArumgents and outputArguments instead of array Signed-off-by: raver119 <raver119@gmail.com> * redundant call Signed-off-by: raver119 <raver119@gmail.com> * print_variable op Signed-off-by: raver119 <raver119@gmail.com> * - view handling (but wrong one) - print_variable java wrapper Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * - empty arrays handling Signed-off-by: raver119 <raver119@gmail.com> * - deserialization works now Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * one more fix Signed-off-by: raver119 <raver119@gmail.com> * initial cuda commit Signed-off-by: raver119 <raver119@gmail.com> * print_variable message validation Signed-off-by: raver119 <raver119@gmail.com> * CUDA views Signed-off-by: raver119 <raver119@gmail.com> * CUDA special buffer size Signed-off-by: raver119 <raver119@gmail.com> * minor update to match master changes Signed-off-by: raver119 <raver119@gmail.com> * - consider arrays always actual on device for CUDA - additional PrintVariable constructor - CudaUtf8Buffer now allocates host buffer by default Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * - print_variable now allows print from device Signed-off-by: raver119 <raver119@gmail.com> * InteropDataBuffer data type fix Signed-off-by: raver119 <raver119@gmail.com> * ... Signed-off-by: raver119 <raver119@gmail.com> * disable some debug messages Signed-off-by: raver119 <raver119@gmail.com> * master pulled in Signed-off-by: raver119 <raver119@gmail.com> * couple of new methods for DataBuffer interop Signed-off-by: raver119 <raver119@gmail.com> * java side Signed-off-by: raver119 <raver119@gmail.com> * offsetted constructor Signed-off-by: raver119 <raver119@gmail.com> * new CUDA deallocator Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart 2 Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart 3 Signed-off-by: raver119 <raver119@gmail.com> * - few new tests - few new methods for DataBuffer management Signed-off-by: raver119 <raver119@gmail.com> * few more tests + few more tweaks Signed-off-by: raver119 <raver119@gmail.com> * two failing tests Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * two failing tests pass Signed-off-by: raver119 <raver119@gmail.com> * now we pass DataBuffer to legacy ops too Signed-off-by: raver119 <raver119@gmail.com> * Native DataBuffer for legacy ops, Java side Signed-off-by: raver119 <raver119@gmail.com> * CPU java side update Signed-off-by: raver119 <raver119@gmail.com> * CUDA java side update Signed-off-by: raver119 <raver119@gmail.com> * no more prepare/register action on java side Signed-off-by: raver119 <raver119@gmail.com> * NDArray::prepare/register use now accepts vectors Signed-off-by: raver119 <raver119@gmail.com> * InteropDataBuffer now has few more convenience methods Signed-off-by: raver119 <raver119@gmail.com> * java bindings update Signed-off-by: raver119 <raver119@gmail.com> * tick device in NativeOps Signed-off-by: raver119 <raver119@gmail.com> * Corrected usage of OpaqueBuffer for tests. * Corrected usage of OpaqueBuffer for java tests. * NativeOpsTests fixes. * print_variable now returns scalar Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * compat_string_split fix for CUDA Signed-off-by: raver119 <raver119@gmail.com> * - CUDA execScalar fix - CUDA lazyAllocateHostPointer now checks java indexer/pointer instead of native pointer Signed-off-by: raver119 <raver119@gmail.com> * legacy ops DataBuffer migration prototype Signed-off-by: raver119 <raver119@gmail.com> * ignore device shapeinfo coming from java Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> * minor transformAny fix Signed-off-by: raver119 <raver119@gmail.com> * minor tweak for lazy host allocation Signed-off-by: raver119 <raver119@gmail.com> * - DataBuffer::memcpy method - bitcast now uses memcpy Signed-off-by: raver119 <raver119@gmail.com> * - IndexReduce CUDA dimension buffer fix Signed-off-by: raver119 <raver119@gmail.com> * views for CPU and CUDA Signed-off-by: raver119 <raver119@gmail.com> * less spam Signed-off-by: raver119 <raver119@gmail.com> * optional memory init Signed-off-by: raver119 <raver119@gmail.com> * async memset Signed-off-by: raver119 <raver119@gmail.com> * - SummaryStats CUDA fix - DataBuffer.sameUnderlyingData() impl - execBroadcast fix Signed-off-by: raver119 <raver119@gmail.com> * - reduce3All fix switch to CUDA 10 temporarily Signed-off-by: raver119 <raver119@gmail.com> * CUDA version Signed-off-by: raver119 <raver119@gmail.com> * proper memory deallocator registration Signed-off-by: raver119 <raver119@gmail.com> * HOST_ONLY workspace allocation Signed-off-by: raver119 <raver119@gmail.com> * temp commit Signed-off-by: raver119 <raver119@gmail.com> * few conflicts resolved Signed-off-by: raver119 <raver119@gmail.com> * few minor fixes Signed-off-by: raver119 <raver119@gmail.com> * one more minor fix Signed-off-by: raver119 <raver119@gmail.com> * NDArray permute should operate on JVM primitives Signed-off-by: raver119 <raver119@gmail.com> * - create InteropDataBuffer for shapes as well - update pointers after view creation in Java Signed-off-by: raver119 <raver119@gmail.com> * - addressPointer temporary moved to C++ Signed-off-by: raver119 <raver119@gmail.com> * CUDA: don't account offset twice Signed-off-by: raver119 <raver119@gmail.com> * CUDA: DataBuffer pointer constructor updated Signed-off-by: raver119 <raver119@gmail.com> * CUDA NDArray.unsafeDuplication() simplified Signed-off-by: raver119 <raver119@gmail.com> * CUDA minor workspace-related fixes Signed-off-by: raver119 <raver119@gmail.com> * CPU DataBuffer.reallocate() Signed-off-by: raver119 <raver119@gmail.com> * print_affinity op Signed-off-by: raver119 <raver119@gmail.com> * print_affinity java side Signed-off-by: raver119 <raver119@gmail.com> * CUDA more tweaks for data locality Signed-off-by: raver119 <raver119@gmail.com> * - compat_string_split tweak - CudaUtf8Buffer update Signed-off-by: raver119 <raver119@gmail.com> * INDArray.close() mechanic restored Signed-off-by: raver119 <raver119@gmail.com> * one more test fixed Signed-off-by: raver119 <raver119@gmail.com> * - CUDA DataBuffer.reallocate() updated - cudaMemcpy (synchronous) restored Signed-off-by: raver119 <raver119@gmail.com> * one last fix Signed-off-by: raver119 <raver119@gmail.com> * bad import removed Signed-off-by: raver119 <raver119@gmail.com> * another small fix Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * fix bad databuffer size Signed-off-by: raver119 <raver119@gmail.com> * release primaryBuffer on replace Signed-off-by: raver119 <raver119@gmail.com> * higher timeout Signed-off-by: raver119 <raver119@gmail.com> * disable timeouts Signed-off-by: raver119 <raver119@gmail.com> * dbCreateView now validates offset and length of a view Signed-off-by: raver119 <raver119@gmail.com> * additional validation for dbExpand Signed-off-by: raver119 <raver119@gmail.com> * restore timeout back again Signed-off-by: raver119 <raver119@gmail.com> * smaller distribution for rng test to prevent timeouts Signed-off-by: raver119 <raver119@gmail.com> * CUDA DataBuffer::memcpy now copies to device all the time Signed-off-by: raver119 <raver119@gmail.com> * OpaqueDataBuffer now contains all required methods for interop Signed-off-by: raver119 <raver119@gmail.com> * some javadoc Signed-off-by: raver119 <raver119@gmail.com> * GC on failed allocations Signed-off-by: raver119 <raver119@gmail.com> * minoe memcpu tweak Signed-off-by: raver119 <raver119@gmail.com> * one more bitcast test Signed-off-by: raver119 <raver119@gmail.com> * - NDArray::deviceId() propagation - special multi-threaded test for data locality checks Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer additional syncStream Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer additional syncStream Signed-off-by: raver119 <raver119@gmail.com> * one ignored test Signed-off-by: raver119 <raver119@gmail.com> * skip host alloc for empty arrays Signed-off-by: raver119 <raver119@gmail.com> * ByteBuffer support is back Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer::memcpy minor fix Signed-off-by: raver119 <raver119@gmail.com> * few minor prelu/bp tweaks Signed-off-by: raver119 <raver119@gmail.com> * nullify-related fixes Signed-off-by: raver119 <raver119@gmail.com> * PReLU fixes (#157) Signed-off-by: Alex Black <blacka101@gmail.com> * Build fixed * Fix tests * one more ByteBuffer signature restored Signed-off-by: raver119 <raver119@gmail.com> * nd4j-jdbc-hsql profiles fix Signed-off-by: raver119 <raver119@gmail.com> * nd4j-jdbc-hsql profiles fix Signed-off-by: raver119 <raver119@gmail.com> * PReLU weight init fix Signed-off-by: Alex Black <blacka101@gmail.com> * Small PReLU fix Signed-off-by: Alex Black <blacka101@gmail.com> * - INDArray.migrate() reactivated - DataBuffer::setDeviceId(...) added - InteropDataBuffer Z syncToDevice added for views Signed-off-by: raver119 <raver119@gmail.com> * missed file Signed-off-by: raver119 <raver119@gmail.com> * Small tweak Signed-off-by: Alex Black <blacka101@gmail.com> * cuda 10.2 Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
451d9d57fd
commit
29e8e09db6
|
@ -121,6 +121,7 @@ public class PReLULayer extends BaseLayer {
|
||||||
public static class Builder extends FeedForwardLayer.Builder<PReLULayer.Builder> {
|
public static class Builder extends FeedForwardLayer.Builder<PReLULayer.Builder> {
|
||||||
|
|
||||||
public Builder(){
|
public Builder(){
|
||||||
|
//Default to 0s, and don't inherit global default
|
||||||
this.weightInitFn = new WeightInitConstant(0);
|
this.weightInitFn = new WeightInitConstant(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.nd4j.linalg.api.buffer.FloatBuffer;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ public class NegativeHolder implements Serializable {
|
||||||
|
|
||||||
protected void makeTable(int tableSize, double power) {
|
protected void makeTable(int tableSize, double power) {
|
||||||
int vocabSize = vocab.numWords();
|
int vocabSize = vocab.numWords();
|
||||||
table = Nd4j.create(new FloatBuffer(tableSize));
|
table = Nd4j.create(DataType.FLOAT, tableSize);
|
||||||
double trainWordsPow = 0.0;
|
double trainWordsPow = 0.0;
|
||||||
for (String word : vocab.words()) {
|
for (String word : vocab.words()) {
|
||||||
trainWordsPow += Math.pow(vocab.wordFrequency(word), power);
|
trainWordsPow += Math.pow(vocab.wordFrequency(word), power);
|
||||||
|
|
|
@ -42,6 +42,8 @@
|
||||||
#include <helpers/ConstantShapeHelper.h>
|
#include <helpers/ConstantShapeHelper.h>
|
||||||
#include <array/DataBuffer.h>
|
#include <array/DataBuffer.h>
|
||||||
#include <execution/AffinityManager.h>
|
#include <execution/AffinityManager.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <array/InteropDataBuffer.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -301,14 +303,11 @@ namespace nd4j {
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
||||||
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
|
||||||
|
|
||||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
|
||||||
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
|
||||||
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
|
||||||
|
|
||||||
|
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
||||||
|
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||||
|
|
|
@ -223,6 +223,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& desc
|
||||||
setShapeInfo(descriptor);
|
setShapeInfo(descriptor);
|
||||||
|
|
||||||
_buffer = buffer;
|
_buffer = buffer;
|
||||||
|
|
||||||
|
_isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -288,6 +290,8 @@ NDArray::NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std
|
||||||
setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape));
|
setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape));
|
||||||
|
|
||||||
_buffer = buffer;
|
_buffer = buffer;
|
||||||
|
|
||||||
|
_isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -68,6 +68,7 @@ bool verbose = false;
|
||||||
#include <array/ConstantDescriptor.h>
|
#include <array/ConstantDescriptor.h>
|
||||||
#include <helpers/ConstantShapeHelper.h>
|
#include <helpers/ConstantShapeHelper.h>
|
||||||
#include <array/ConstantDataBuffer.h>
|
#include <array/ConstantDataBuffer.h>
|
||||||
|
#include <array/InteropDataBuffer.h>
|
||||||
#include <helpers/ConstantHelper.h>
|
#include <helpers/ConstantHelper.h>
|
||||||
#include <array/TadPack.h>
|
#include <array/TadPack.h>
|
||||||
#include <graph/VariablesSet.h>
|
#include <graph/VariablesSet.h>
|
||||||
|
@ -76,6 +77,8 @@ bool verbose = false;
|
||||||
#include <graph/ResultWrapper.h>
|
#include <graph/ResultWrapper.h>
|
||||||
#include <DebugInfo.h>
|
#include <DebugInfo.h>
|
||||||
|
|
||||||
|
typedef nd4j::InteropDataBuffer OpaqueDataBuffer;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -118,11 +121,9 @@ ND4J_EXPORT void setTADThreshold(int num);
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -137,13 +138,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -160,28 +158,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers,
|
||||||
ND4J_EXPORT void execBroadcast(
|
ND4J_EXPORT void execBroadcast(
|
||||||
Nd4jPointer *extraPointers,
|
Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execBroadcastBool(
|
ND4J_EXPORT void execBroadcastBool(
|
||||||
Nd4jPointer *extraPointers,
|
Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -198,23 +188,17 @@ ND4J_EXPORT void execBroadcastBool(
|
||||||
ND4J_EXPORT void execPairwiseTransform(
|
ND4J_EXPORT void execPairwiseTransform(
|
||||||
Nd4jPointer *extraPointers,
|
Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execPairwiseTransformBool(
|
ND4J_EXPORT void execPairwiseTransformBool(
|
||||||
Nd4jPointer *extraPointers,
|
Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -228,36 +212,28 @@ ND4J_EXPORT void execPairwiseTransformBool(
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -270,46 +246,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape);
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -324,13 +288,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParamsVals,
|
void *extraParamsVals,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -343,13 +304,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParamsVals,
|
void *extraParamsVals,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo);
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo);
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
@ -365,30 +323,22 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParamsVals,
|
void *extraParamsVals,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape,
|
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
|
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets);
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParamsVals,
|
void *extraParamsVals,
|
||||||
void *hY, Nd4jLong *hYShapeInfo,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape,
|
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
||||||
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
|
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets);
|
||||||
|
|
||||||
|
@ -405,22 +355,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hScalar, Nd4jLong *hSscalarShapeInfo,
|
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hScalar, Nd4jLong *hSscalarShapeInfo,
|
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -432,11 +376,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
bool biasCorrected);
|
bool biasCorrected);
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -449,11 +391,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
bool biasCorrected);
|
bool biasCorrected);
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -468,13 +408,10 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
|
||||||
void *dDimension, Nd4jLong *dDimensionShape,
|
|
||||||
bool biasCorrected,
|
bool biasCorrected,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
|
||||||
|
|
||||||
|
@ -490,42 +427,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *extraParams);
|
void *extraParams);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -543,29 +470,21 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hScalars, Nd4jLong *hScalarShapeInfo,
|
|
||||||
void *dScalars, Nd4jLong *dScalarShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
|
||||||
void *dDimension, Nd4jLong *dDimensionShape,
|
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
void *hX, Nd4jLong *hXShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
|
||||||
void *hScalars, Nd4jLong *hScalarShapeInfo,
|
|
||||||
void *dScalars, Nd4jLong *dScalarShapeInfo,
|
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape,
|
||||||
void *dDimension, Nd4jLong *dDimensionShape,
|
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
|
@ -904,10 +823,8 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
|
||||||
* @param zTadOffsets
|
* @param zTadOffsets
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers,
|
||||||
void *x, Nd4jLong *xShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
|
||||||
void *dx, Nd4jLong *dxShapeInfo,
|
OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo,
|
||||||
void *z, Nd4jLong *zShapeInfo,
|
|
||||||
void *dz, Nd4jLong *dzShapeInfo,
|
|
||||||
Nd4jLong n,
|
Nd4jLong n,
|
||||||
Nd4jLong *indexes,
|
Nd4jLong *indexes,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
|
@ -1086,8 +1003,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers,
|
||||||
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Nd4jPointer state,
|
Nd4jPointer state,
|
||||||
void *hZ, Nd4jLong *hZShapeBuffer,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
|
||||||
void *dZ, Nd4jLong *dZShapeBuffer,
|
|
||||||
void *extraArguments);
|
void *extraArguments);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1106,12 +1022,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers,
|
||||||
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Nd4jPointer state,
|
Nd4jPointer state,
|
||||||
void *hX, Nd4jLong *hXShapeBuffer,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
|
||||||
void *dX, Nd4jLong *dXShapeBuffer,
|
OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer,
|
||||||
void *hY, Nd4jLong *hYShapeBuffer,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
|
||||||
void *dY, Nd4jLong *dYShapeBuffer,
|
|
||||||
void *hZ, Nd4jLong *hZShapeBuffer,
|
|
||||||
void *dZ, Nd4jLong *dZShapeBuffer,
|
|
||||||
void *extraArguments);
|
void *extraArguments);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1128,10 +1041,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers,
|
||||||
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Nd4jPointer state,
|
Nd4jPointer state,
|
||||||
void *hX, Nd4jLong *hXShapeBuffer,
|
OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer,
|
||||||
void *dX, Nd4jLong *dXShapeBuffer,
|
OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer,
|
||||||
void *hZ, Nd4jLong *hZShapeBuffer,
|
|
||||||
void *dZ, Nd4jLong *dZShapeBuffer,
|
|
||||||
void *extraArguments);
|
void *extraArguments);
|
||||||
|
|
||||||
|
|
||||||
|
@ -1174,52 +1085,6 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers,
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
|
ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom);
|
||||||
|
|
||||||
/**
|
|
||||||
* Grid operations
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param extras
|
|
||||||
* @param opTypeA
|
|
||||||
* @param opNumA
|
|
||||||
* @param opTypeB
|
|
||||||
* @param opNumB
|
|
||||||
* @param N
|
|
||||||
* @param dx
|
|
||||||
* @param xShapeInfo
|
|
||||||
* @param dy
|
|
||||||
* @param yShapeInfo
|
|
||||||
* @param dz
|
|
||||||
* @param zShapeInfo
|
|
||||||
* @param extraA
|
|
||||||
* @param extraB
|
|
||||||
* @param scalarA
|
|
||||||
* @param scalarB
|
|
||||||
*/
|
|
||||||
/*
|
|
||||||
ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras,
|
|
||||||
const int opTypeA,
|
|
||||||
const int opNumA,
|
|
||||||
const int opTypeB,
|
|
||||||
const int opNumB,
|
|
||||||
Nd4jLong N,
|
|
||||||
void *hX, Nd4jLong *hXShapeBuffer,
|
|
||||||
void *dX, Nd4jLong *dXShapeBuffer,
|
|
||||||
void *hY, Nd4jLong *hYShapeBuffer,
|
|
||||||
void *dY, Nd4jLong *dYShapeBuffer,
|
|
||||||
void *hZ, Nd4jLong *hZShapeBuffer,
|
|
||||||
void *dZ, Nd4jLong *dZShapeBuffer,
|
|
||||||
void *extraA,
|
|
||||||
void *extraB,
|
|
||||||
double scalarA,
|
|
||||||
double scalarB);
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1561,11 +1426,10 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address);
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
|
ND4J_EXPORT void tear(Nd4jPointer *extraPointers,
|
||||||
void *x, Nd4jLong *xShapeInfo,
|
OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo,
|
||||||
void *dx, Nd4jLong *dxShapeInfo,
|
Nd4jPointer *targets, Nd4jLong *zShapeInfo,
|
||||||
Nd4jPointer *targets, Nd4jLong *zShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadOffsets);
|
||||||
Nd4jLong *tadOffsets);
|
|
||||||
|
|
||||||
ND4J_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold);
|
ND4J_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold);
|
||||||
ND4J_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
|
ND4J_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
|
||||||
|
@ -1739,6 +1603,8 @@ ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace)
|
||||||
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
|
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
|
||||||
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
|
ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
|
ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
||||||
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
||||||
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
||||||
|
@ -1766,6 +1632,28 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc);
|
||||||
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
|
ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
|
||||||
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
|
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
|
||||||
|
|
||||||
|
ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth);
|
||||||
|
ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset);
|
||||||
|
ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
|
||||||
|
ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes);
|
||||||
|
ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes);
|
||||||
|
ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId);
|
||||||
|
ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer);
|
||||||
|
ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements);
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT int binaryLevel();
|
ND4J_EXPORT int binaryLevel();
|
||||||
ND4J_EXPORT int optimalLevel();
|
ND4J_EXPORT int optimalLevel();
|
||||||
|
|
|
@ -184,16 +184,16 @@ void NDArray::synchronize(const char* msg) const {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -236,7 +236,7 @@ void NDArray::synchronize(const char* msg) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
if(a != nullptr)
|
if(a != nullptr)
|
||||||
|
@ -252,7 +252,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
if(p != nullptr)
|
if(p != nullptr)
|
||||||
|
@ -264,7 +264,7 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
if(a != nullptr)
|
if(a != nullptr)
|
||||||
|
@ -280,7 +280,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
if(p != nullptr)
|
if(p != nullptr)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -34,10 +34,12 @@
|
||||||
#define ARRAY_SPARSE 2
|
#define ARRAY_SPARSE 2
|
||||||
#define ARRAY_COMPRESSED 4
|
#define ARRAY_COMPRESSED 4
|
||||||
#define ARRAY_EMPTY 8
|
#define ARRAY_EMPTY 8
|
||||||
|
#define ARRAY_RAGGED 16
|
||||||
|
|
||||||
#define ARRAY_CSR 16
|
|
||||||
#define ARRAY_CSC 32
|
#define ARRAY_CSR 32
|
||||||
#define ARRAY_COO 64
|
#define ARRAY_CSC 64
|
||||||
|
#define ARRAY_COO 128
|
||||||
|
|
||||||
// complex values
|
// complex values
|
||||||
#define ARRAY_COMPLEX 512
|
#define ARRAY_COMPLEX 512
|
||||||
|
@ -72,8 +74,10 @@
|
||||||
// boolean values
|
// boolean values
|
||||||
#define ARRAY_BOOL 524288
|
#define ARRAY_BOOL 524288
|
||||||
|
|
||||||
// utf-8 values
|
// UTF values
|
||||||
#define ARRAY_STRING 1048576
|
#define ARRAY_UTF8 1048576
|
||||||
|
#define ARRAY_UTF16 4194304
|
||||||
|
#define ARRAY_UTF32 16777216
|
||||||
|
|
||||||
// flag for extras
|
// flag for extras
|
||||||
#define ARRAY_EXTRAS 2097152
|
#define ARRAY_EXTRAS 2097152
|
||||||
|
@ -173,8 +177,12 @@ namespace nd4j {
|
||||||
return nd4j::DataType ::UINT32;
|
return nd4j::DataType ::UINT32;
|
||||||
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
|
||||||
return nd4j::DataType ::UINT64;
|
return nd4j::DataType ::UINT64;
|
||||||
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING))
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
|
||||||
return nd4j::DataType ::UTF8;
|
return nd4j::DataType ::UTF8;
|
||||||
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
|
||||||
|
return nd4j::DataType ::UTF16;
|
||||||
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
|
||||||
|
return nd4j::DataType ::UTF32;
|
||||||
else {
|
else {
|
||||||
//shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
|
//shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
|
||||||
#ifndef __CUDA_ARCH__
|
#ifndef __CUDA_ARCH__
|
||||||
|
@ -190,8 +198,12 @@ namespace nd4j {
|
||||||
return nd4j::DataType::INT32;
|
return nd4j::DataType::INT32;
|
||||||
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
|
||||||
return nd4j::DataType::INT64;
|
return nd4j::DataType::INT64;
|
||||||
else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING))
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
|
||||||
return nd4j::DataType::UTF8;
|
return nd4j::DataType::UTF8;
|
||||||
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
|
||||||
|
return nd4j::DataType::UTF16;
|
||||||
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
|
||||||
|
return nd4j::DataType::UTF32;
|
||||||
else {
|
else {
|
||||||
//shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
|
//shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
|
||||||
#ifndef __CUDA_ARCH__
|
#ifndef __CUDA_ARCH__
|
||||||
|
@ -224,6 +236,8 @@ namespace nd4j {
|
||||||
return ArrayType::COMPRESSED;
|
return ArrayType::COMPRESSED;
|
||||||
else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
|
||||||
return ArrayType::EMPTY;
|
return ArrayType::EMPTY;
|
||||||
|
else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED))
|
||||||
|
return ArrayType::RAGGED;
|
||||||
else // by default we return DENSE type here
|
else // by default we return DENSE type here
|
||||||
return ArrayType::DENSE;
|
return ArrayType::DENSE;
|
||||||
}
|
}
|
||||||
|
@ -333,7 +347,13 @@ namespace nd4j {
|
||||||
setPropertyBit(shapeInfo, ARRAY_LONG);
|
setPropertyBit(shapeInfo, ARRAY_LONG);
|
||||||
break;
|
break;
|
||||||
case nd4j::DataType::UTF8:
|
case nd4j::DataType::UTF8:
|
||||||
setPropertyBit(shapeInfo, ARRAY_STRING);
|
setPropertyBit(shapeInfo, ARRAY_UTF8);
|
||||||
|
break;
|
||||||
|
case nd4j::DataType::UTF16:
|
||||||
|
setPropertyBit(shapeInfo, ARRAY_UTF16);
|
||||||
|
break;
|
||||||
|
case nd4j::DataType::UTF32:
|
||||||
|
setPropertyBit(shapeInfo, ARRAY_UTF32);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
#ifndef __CUDA_ARCH__
|
#ifndef __CUDA_ARCH__
|
||||||
|
|
|
@ -27,6 +27,7 @@ namespace nd4j {
|
||||||
SPARSE = 2,
|
SPARSE = 2,
|
||||||
COMPRESSED = 3,
|
COMPRESSED = 3,
|
||||||
EMPTY = 4,
|
EMPTY = 4,
|
||||||
|
RAGGED = 5,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,13 +36,14 @@ class ND4J_EXPORT DataBuffer {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
void* _primaryBuffer;
|
void* _primaryBuffer = nullptr;
|
||||||
void* _specialBuffer;
|
void* _specialBuffer = nullptr;
|
||||||
size_t _lenInBytes;
|
size_t _lenInBytes = 0;
|
||||||
DataType _dataType;
|
DataType _dataType;
|
||||||
memory::Workspace* _workspace;
|
memory::Workspace* _workspace = nullptr;
|
||||||
bool _isOwnerPrimary;
|
bool _isOwnerPrimary;
|
||||||
bool _isOwnerSpecial;
|
bool _isOwnerSpecial;
|
||||||
|
std::atomic<int> _deviceId;
|
||||||
|
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
mutable std::atomic<Nd4jLong> _counter;
|
mutable std::atomic<Nd4jLong> _counter;
|
||||||
|
@ -52,51 +53,52 @@ class ND4J_EXPORT DataBuffer {
|
||||||
mutable std::atomic<Nd4jLong> _readSpecial;
|
mutable std::atomic<Nd4jLong> _readSpecial;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void setCountersToZero();
|
void setCountersToZero();
|
||||||
void copyCounters(const DataBuffer& other);
|
void copyCounters(const DataBuffer& other);
|
||||||
void deleteSpecial();
|
void deleteSpecial();
|
||||||
FORCEINLINE void deletePrimary();
|
void deletePrimary();
|
||||||
FORCEINLINE void deleteBuffers();
|
void deleteBuffers();
|
||||||
FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
|
void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
|
||||||
void allocateBuffers(const bool allocBoth = false);
|
void allocateBuffers(const bool allocBoth = false);
|
||||||
void setSpecial(void* special, const bool isOwnerSpecial);
|
void setSpecial(void* special, const bool isOwnerSpecial);
|
||||||
void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0);
|
void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0);
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
FORCEINLINE DataBuffer(void* primary, void* special,
|
DataBuffer(void* primary, void* special,
|
||||||
const size_t lenInBytes, const DataType dataType,
|
const size_t lenInBytes, const DataType dataType,
|
||||||
const bool isOwnerPrimary = false, const bool isOwnerSpecial = false,
|
const bool isOwnerPrimary = false, const bool isOwnerSpecial = false,
|
||||||
memory::Workspace* workspace = nullptr);
|
memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
FORCEINLINE DataBuffer(void* primary,
|
DataBuffer(void* primary,
|
||||||
const size_t lenInBytes, const DataType dataType,
|
const size_t lenInBytes, const DataType dataType,
|
||||||
const bool isOwnerPrimary = false,
|
const bool isOwnerPrimary = false,
|
||||||
memory::Workspace* workspace = nullptr);
|
memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
|
DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
|
||||||
const DataType dataType, const size_t lenInBytes,
|
const DataType dataType, const size_t lenInBytes,
|
||||||
memory::Workspace* workspace = nullptr);
|
memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
|
DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
|
||||||
|
|
||||||
FORCEINLINE DataBuffer(const DataBuffer& other);
|
DataBuffer(const DataBuffer& other);
|
||||||
FORCEINLINE DataBuffer(DataBuffer&& other);
|
DataBuffer(DataBuffer&& other);
|
||||||
FORCEINLINE explicit DataBuffer();
|
explicit DataBuffer();
|
||||||
FORCEINLINE ~DataBuffer();
|
~DataBuffer();
|
||||||
|
|
||||||
FORCEINLINE DataBuffer& operator=(const DataBuffer& other);
|
DataBuffer& operator=(const DataBuffer& other);
|
||||||
FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept;
|
DataBuffer& operator=(DataBuffer&& other) noexcept;
|
||||||
|
|
||||||
FORCEINLINE DataType getDataType();
|
DataType getDataType();
|
||||||
FORCEINLINE size_t getLenInBytes() const;
|
void setDataType(DataType dataType);
|
||||||
|
size_t getLenInBytes() const;
|
||||||
|
|
||||||
FORCEINLINE void* primary();
|
void* primary();
|
||||||
FORCEINLINE void* special();
|
void* special();
|
||||||
|
|
||||||
FORCEINLINE void allocatePrimary();
|
void allocatePrimary();
|
||||||
void allocateSpecial();
|
void allocateSpecial();
|
||||||
|
|
||||||
void writePrimary() const;
|
void writePrimary() const;
|
||||||
void writeSpecial() const;
|
void writeSpecial() const;
|
||||||
|
@ -105,6 +107,10 @@ class ND4J_EXPORT DataBuffer {
|
||||||
bool isPrimaryActual() const;
|
bool isPrimaryActual() const;
|
||||||
bool isSpecialActual() const;
|
bool isSpecialActual() const;
|
||||||
|
|
||||||
|
void expand(const uint64_t size);
|
||||||
|
|
||||||
|
int deviceId() const;
|
||||||
|
void setDeviceId(int deviceId);
|
||||||
void migrate();
|
void migrate();
|
||||||
|
|
||||||
template <typename T> FORCEINLINE T* primaryAsT();
|
template <typename T> FORCEINLINE T* primaryAsT();
|
||||||
|
@ -118,256 +124,28 @@ class ND4J_EXPORT DataBuffer {
|
||||||
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
|
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
|
||||||
|
|
||||||
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
|
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
|
||||||
|
|
||||||
|
void setPrimaryBuffer(void *buffer, size_t length);
|
||||||
|
void setSpecialBuffer(void *buffer, size_t length);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method deletes buffers, if we're owners
|
||||||
|
*/
|
||||||
|
void close();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
///// IMLEMENTATION OF INLINE METHODS /////
|
///// IMLEMENTATION OF INLINE METHODS /////
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// default constructor
|
template <typename T>
|
||||||
DataBuffer::DataBuffer() {
|
T* DataBuffer::primaryAsT() {
|
||||||
|
return reinterpret_cast<T*>(_primaryBuffer);
|
||||||
_primaryBuffer = nullptr;
|
|
||||||
_specialBuffer = nullptr;
|
|
||||||
_lenInBytes = 0;
|
|
||||||
_dataType = INT8;
|
|
||||||
_workspace = nullptr;
|
|
||||||
_isOwnerPrimary = false;
|
|
||||||
_isOwnerSpecial = false;
|
|
||||||
|
|
||||||
setCountersToZero();
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
// copy constructor
|
|
||||||
DataBuffer::DataBuffer(const DataBuffer &other) {
|
|
||||||
|
|
||||||
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
|
|
||||||
|
|
||||||
_lenInBytes = other._lenInBytes;
|
|
||||||
_dataType = other._dataType;
|
|
||||||
_workspace = other._workspace;
|
|
||||||
|
|
||||||
_primaryBuffer = nullptr;
|
|
||||||
_specialBuffer = nullptr;
|
|
||||||
|
|
||||||
setCountersToZero();
|
|
||||||
|
|
||||||
allocateBuffers();
|
|
||||||
copyBufferFrom(other);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
DataBuffer::DataBuffer(void* primary, void* special,
|
|
||||||
const size_t lenInBytes, const DataType dataType,
|
|
||||||
const bool isOwnerPrimary, const bool isOwnerSpecial,
|
|
||||||
memory::Workspace* workspace) {
|
|
||||||
|
|
||||||
if (primary == nullptr && special == nullptr)
|
|
||||||
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
|
|
||||||
|
|
||||||
_primaryBuffer = primary;
|
|
||||||
_specialBuffer = special;
|
|
||||||
_lenInBytes = lenInBytes;
|
|
||||||
_dataType = dataType;
|
|
||||||
_workspace = workspace;
|
|
||||||
_isOwnerPrimary = isOwnerPrimary;
|
|
||||||
_isOwnerSpecial = isOwnerSpecial;
|
|
||||||
|
|
||||||
setCountersToZero();
|
|
||||||
|
|
||||||
if(primary != nullptr)
|
|
||||||
readPrimary();
|
|
||||||
if(special != nullptr)
|
|
||||||
readSpecial();
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
|
|
||||||
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
|
|
||||||
|
|
||||||
syncToSpecial(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
// copies data from hostBuffer to own memory buffer
|
|
||||||
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
|
|
||||||
|
|
||||||
if (hostBuffer == nullptr)
|
|
||||||
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
|
|
||||||
if (lenInBytes == 0)
|
|
||||||
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
|
|
||||||
|
|
||||||
_primaryBuffer = nullptr;
|
|
||||||
_specialBuffer = nullptr;
|
|
||||||
_lenInBytes = lenInBytes;
|
|
||||||
_dataType = dataType;
|
|
||||||
_workspace = workspace;
|
|
||||||
|
|
||||||
setCountersToZero();
|
|
||||||
|
|
||||||
allocateBuffers();
|
|
||||||
|
|
||||||
copyBufferFromHost(hostBuffer, lenInBytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
|
|
||||||
|
|
||||||
_dataType = dataType;
|
|
||||||
_workspace = workspace;
|
|
||||||
_lenInBytes = lenInBytes;
|
|
||||||
|
|
||||||
_primaryBuffer = nullptr;
|
|
||||||
_specialBuffer = nullptr;
|
|
||||||
|
|
||||||
setCountersToZero();
|
|
||||||
|
|
||||||
if(lenInBytes != 0) {
|
|
||||||
allocateBuffers(allocBoth);
|
|
||||||
writeSpecial();
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// move constructor
|
template <typename T>
|
||||||
DataBuffer::DataBuffer(DataBuffer&& other) {
|
T* DataBuffer::specialAsT() {
|
||||||
|
return reinterpret_cast<T*>(_specialBuffer);
|
||||||
_primaryBuffer = other._primaryBuffer;
|
|
||||||
_specialBuffer = other._specialBuffer;
|
|
||||||
_lenInBytes = other._lenInBytes;
|
|
||||||
_dataType = other._dataType;
|
|
||||||
_workspace = other._workspace;
|
|
||||||
_isOwnerPrimary = other._isOwnerPrimary;
|
|
||||||
_isOwnerSpecial = other._isOwnerSpecial;
|
|
||||||
|
|
||||||
copyCounters(other);
|
|
||||||
|
|
||||||
other._primaryBuffer = other._specialBuffer = nullptr;
|
|
||||||
other.setAllocFlags(false, false);
|
|
||||||
other._lenInBytes = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
// assignment operator
|
|
||||||
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
|
|
||||||
|
|
||||||
if (this == &other)
|
|
||||||
return *this;
|
|
||||||
|
|
||||||
deleteBuffers();
|
|
||||||
|
|
||||||
_lenInBytes = other._lenInBytes;
|
|
||||||
_dataType = other._dataType;
|
|
||||||
_workspace = other._workspace;
|
|
||||||
|
|
||||||
allocateBuffers();
|
|
||||||
copyBufferFrom(other);
|
|
||||||
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
// move assignment operator
|
|
||||||
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
|
|
||||||
|
|
||||||
if (this == &other)
|
|
||||||
return *this;
|
|
||||||
|
|
||||||
deleteBuffers();
|
|
||||||
|
|
||||||
_primaryBuffer = other._primaryBuffer;
|
|
||||||
_specialBuffer = other._specialBuffer;
|
|
||||||
_lenInBytes = other._lenInBytes;
|
|
||||||
_dataType = other._dataType;
|
|
||||||
_workspace = other._workspace;
|
|
||||||
_isOwnerPrimary = other._isOwnerPrimary;
|
|
||||||
_isOwnerSpecial = other._isOwnerSpecial;
|
|
||||||
|
|
||||||
copyCounters(other);
|
|
||||||
|
|
||||||
other._primaryBuffer = other._specialBuffer = nullptr;
|
|
||||||
other.setAllocFlags(false, false);
|
|
||||||
other._lenInBytes = 0;
|
|
||||||
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void* DataBuffer::primary() {
|
|
||||||
return _primaryBuffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void* DataBuffer::special() {
|
|
||||||
return _specialBuffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
DataType DataBuffer::getDataType() {
|
|
||||||
return _dataType;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
size_t DataBuffer::getLenInBytes() const {
|
|
||||||
return _lenInBytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
T* DataBuffer::primaryAsT() {
|
|
||||||
return reinterpret_cast<T*>(_primaryBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
T* DataBuffer::specialAsT() {
|
|
||||||
return reinterpret_cast<T*>(_specialBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::allocatePrimary() {
|
|
||||||
|
|
||||||
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
|
|
||||||
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
|
|
||||||
_isOwnerPrimary = true;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
|
|
||||||
|
|
||||||
_isOwnerPrimary = isOwnerPrimary;
|
|
||||||
_isOwnerSpecial = isOwnerSpecial;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::deletePrimary() {
|
|
||||||
|
|
||||||
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
|
|
||||||
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
|
|
||||||
RELEASE(p, _workspace);
|
|
||||||
_primaryBuffer = nullptr;
|
|
||||||
_isOwnerPrimary = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::deleteBuffers() {
|
|
||||||
|
|
||||||
deletePrimary();
|
|
||||||
deleteSpecial();
|
|
||||||
_lenInBytes = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
DataBuffer::~DataBuffer() {
|
|
||||||
|
|
||||||
deleteBuffers();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,8 @@ namespace nd4j {
|
||||||
QINT16 = 16,
|
QINT16 = 16,
|
||||||
BFLOAT16 = 17,
|
BFLOAT16 = 17,
|
||||||
UTF8 = 50,
|
UTF8 = 50,
|
||||||
|
UTF16 = 51,
|
||||||
|
UTF32 = 52,
|
||||||
ANY = 100,
|
ANY = 100,
|
||||||
AUTO = 200,
|
AUTO = 200,
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <dll.h>
|
||||||
|
#include <array/DataBuffer.h>
|
||||||
|
#include <array/DataType.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#ifndef LIBND4J_INTEROPDATABUFFER_H
|
||||||
|
#define LIBND4J_INTEROPDATABUFFER_H
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
/**
|
||||||
|
* This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages
|
||||||
|
*/
|
||||||
|
class ND4J_EXPORT InteropDataBuffer {
|
||||||
|
private:
|
||||||
|
std::shared_ptr<DataBuffer> _dataBuffer;
|
||||||
|
uint64_t _offset = 0;
|
||||||
|
public:
|
||||||
|
InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset);
|
||||||
|
InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer);
|
||||||
|
InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth);
|
||||||
|
~InteropDataBuffer() = default;
|
||||||
|
|
||||||
|
#ifndef __JAVACPP_HACK__
|
||||||
|
std::shared_ptr<DataBuffer> getDataBuffer() const;
|
||||||
|
std::shared_ptr<DataBuffer> dataBuffer();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void* primary() const;
|
||||||
|
void* special() const;
|
||||||
|
|
||||||
|
uint64_t offset() const ;
|
||||||
|
void setOffset(uint64_t offset);
|
||||||
|
|
||||||
|
void setPrimary(void* ptr, size_t length);
|
||||||
|
void setSpecial(void* ptr, size_t length);
|
||||||
|
|
||||||
|
void expand(size_t newlength);
|
||||||
|
|
||||||
|
int deviceId() const;
|
||||||
|
void setDeviceId(int deviceId);
|
||||||
|
|
||||||
|
static void registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
|
||||||
|
static void prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
|
static void registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList);
|
||||||
|
static void preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables = false);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //LIBND4J_INTEROPDATABUFFER_H
|
|
@ -23,6 +23,24 @@
|
||||||
#include <DataTypeUtils.h>
|
#include <DataTypeUtils.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
void DataBuffer::expand(const uint64_t size) {
|
||||||
|
if (size > _lenInBytes) {
|
||||||
|
// allocate new buffer
|
||||||
|
int8_t *newBuffer = nullptr;
|
||||||
|
ALLOCATE(newBuffer, _workspace, size, int8_t);
|
||||||
|
|
||||||
|
// copy data from existing buffer
|
||||||
|
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
|
||||||
|
|
||||||
|
if (_isOwnerPrimary) {
|
||||||
|
RELEASE(reinterpret_cast<int8_t *>(_primaryBuffer), _workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
_primaryBuffer = newBuffer;
|
||||||
|
_lenInBytes = size;
|
||||||
|
_isOwnerPrimary = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::setCountersToZero() {
|
void DataBuffer::setCountersToZero() {
|
||||||
|
@ -99,14 +117,17 @@ void DataBuffer::allocateSpecial() {
|
||||||
void DataBuffer::migrate() {
|
void DataBuffer::migrate() {
|
||||||
|
|
||||||
}
|
}
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
|
||||||
if (src._lenInBytes < dst._lenInBytes)
|
|
||||||
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
|
||||||
|
|
||||||
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes);
|
/////////////////////////
|
||||||
|
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||||
|
if (src._lenInBytes > dst._lenInBytes)
|
||||||
|
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
|
||||||
|
|
||||||
|
std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes);
|
||||||
|
dst.readPrimary();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::writePrimary() const { }
|
void DataBuffer::writePrimary() const { }
|
||||||
void DataBuffer::writeSpecial() const { }
|
void DataBuffer::writeSpecial() const { }
|
||||||
|
|
|
@ -25,6 +25,40 @@
|
||||||
#include <exceptions/cuda_exception.h>
|
#include <exceptions/cuda_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
void DataBuffer::expand(const uint64_t size) {
|
||||||
|
if (size > _lenInBytes) {
|
||||||
|
// allocate new buffer
|
||||||
|
int8_t *newBuffer = nullptr;
|
||||||
|
int8_t *newSpecialBuffer = nullptr;
|
||||||
|
ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t);
|
||||||
|
|
||||||
|
// copy data from existing buffer
|
||||||
|
if (_primaryBuffer != nullptr) {
|
||||||
|
// there's non-zero chance that primary buffer doesn't exist yet
|
||||||
|
ALLOCATE(newBuffer, _workspace, size, int8_t);
|
||||||
|
std::memcpy(newBuffer, _primaryBuffer, _lenInBytes);
|
||||||
|
|
||||||
|
if (_isOwnerPrimary) {
|
||||||
|
auto ipb = reinterpret_cast<int8_t *>(_primaryBuffer);
|
||||||
|
RELEASE(ipb, _workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
_primaryBuffer = newBuffer;
|
||||||
|
_isOwnerPrimary = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice);
|
||||||
|
|
||||||
|
if (_isOwnerSpecial) {
|
||||||
|
auto isb = reinterpret_cast<int8_t *>(_specialBuffer);
|
||||||
|
RELEASE_SPECIAL(isb, _workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
_specialBuffer = newSpecialBuffer;
|
||||||
|
_lenInBytes = size;
|
||||||
|
_isOwnerSpecial = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::allocateSpecial() {
|
void DataBuffer::allocateSpecial() {
|
||||||
|
@ -37,8 +71,9 @@ void DataBuffer::allocateSpecial() {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) {
|
void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) {
|
||||||
if(isPrimaryActual() && !forceSync)
|
if(isPrimaryActual() && !forceSync) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
allocatePrimary();
|
allocatePrimary();
|
||||||
|
|
||||||
|
@ -46,7 +81,9 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res);
|
throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res);
|
||||||
|
|
||||||
cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost);
|
res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res);
|
||||||
|
|
||||||
readPrimary();
|
readPrimary();
|
||||||
}
|
}
|
||||||
|
@ -54,13 +91,19 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::syncToSpecial(const bool forceSync) {
|
void DataBuffer::syncToSpecial(const bool forceSync) {
|
||||||
|
// in this case there's nothing to do here
|
||||||
if(isSpecialActual() && !forceSync)
|
if (_primaryBuffer == nullptr)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
if(isSpecialActual() && !forceSync) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
allocateSpecial();
|
allocateSpecial();
|
||||||
|
|
||||||
cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice);
|
auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res);
|
||||||
|
|
||||||
readSpecial();
|
readSpecial();
|
||||||
}
|
}
|
||||||
|
@ -97,19 +140,6 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
|
||||||
_readPrimary.store(other._writeSpecial);
|
_readPrimary.store(other._writeSpecial);
|
||||||
_readSpecial.store(other._writePrimary);
|
_readSpecial.store(other._writePrimary);
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
|
||||||
if (src._lenInBytes < dst._lenInBytes)
|
|
||||||
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
|
||||||
|
|
||||||
if (src.isSpecialActual()) {
|
|
||||||
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
|
|
||||||
} else if (src.isPrimaryActual()) {
|
|
||||||
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.writeSpecial();
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
|
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
|
||||||
|
@ -176,8 +206,11 @@ void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate s
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::setToZeroBuffers(const bool both) {
|
void DataBuffer::setToZeroBuffers(const bool both) {
|
||||||
|
cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res);
|
||||||
|
|
||||||
cudaMemset(special(), 0, getLenInBytes());
|
|
||||||
writeSpecial();
|
writeSpecial();
|
||||||
|
|
||||||
if(both) {
|
if(both) {
|
||||||
|
@ -186,12 +219,37 @@ void DataBuffer::setToZeroBuffers(const bool both) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////////
|
||||||
|
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||||
|
if (src._lenInBytes > dst._lenInBytes)
|
||||||
|
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination");
|
||||||
|
|
||||||
|
|
||||||
|
int res = 0;
|
||||||
|
if (src.isSpecialActual()) {
|
||||||
|
res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
} else if (src.isPrimaryActual()) {
|
||||||
|
res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res);
|
||||||
|
|
||||||
|
res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res);
|
||||||
|
|
||||||
|
dst.writeSpecial();
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::migrate() {
|
void DataBuffer::migrate() {
|
||||||
memory::Workspace* newWorkspace = nullptr;
|
memory::Workspace* newWorkspace = nullptr;
|
||||||
void* newBuffer;
|
void* newBuffer;
|
||||||
ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
|
ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
|
||||||
cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
|
auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res);
|
||||||
|
|
||||||
if (_isOwnerSpecial) {
|
if (_isOwnerSpecial) {
|
||||||
// now we're releasing original buffer
|
// now we're releasing original buffer
|
||||||
|
@ -203,7 +261,7 @@ void DataBuffer::migrate() {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::writePrimary() const { _writePrimary = ++_counter; }
|
void DataBuffer::writePrimary() const {_writePrimary = ++_counter; }
|
||||||
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }
|
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }
|
||||||
void DataBuffer::readPrimary() const { _readPrimary = ++_counter; }
|
void DataBuffer::readPrimary() const { _readPrimary = ++_counter; }
|
||||||
void DataBuffer::readSpecial() const { _readSpecial = ++_counter; }
|
void DataBuffer::readSpecial() const { _readSpecial = ++_counter; }
|
||||||
|
|
|
@ -0,0 +1,301 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <array/DataBuffer.h>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
#include <array/DataTypeUtils.h>
|
||||||
|
#include <execution/AffinityManager.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
///// IMLEMENTATION OF COMMON METHODS /////
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// default constructor
|
||||||
|
DataBuffer::DataBuffer() {
|
||||||
|
|
||||||
|
_primaryBuffer = nullptr;
|
||||||
|
_specialBuffer = nullptr;
|
||||||
|
_lenInBytes = 0;
|
||||||
|
_dataType = INT8;
|
||||||
|
_workspace = nullptr;
|
||||||
|
_isOwnerPrimary = false;
|
||||||
|
_isOwnerSpecial = false;
|
||||||
|
_deviceId = nd4j::AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
setCountersToZero();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// copy constructor
|
||||||
|
DataBuffer::DataBuffer(const DataBuffer &other) {
|
||||||
|
|
||||||
|
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
|
||||||
|
|
||||||
|
_lenInBytes = other._lenInBytes;
|
||||||
|
_dataType = other._dataType;
|
||||||
|
_workspace = other._workspace;
|
||||||
|
|
||||||
|
_primaryBuffer = nullptr;
|
||||||
|
_specialBuffer = nullptr;
|
||||||
|
|
||||||
|
_deviceId.store(other._deviceId.load());
|
||||||
|
|
||||||
|
setCountersToZero();
|
||||||
|
|
||||||
|
allocateBuffers();
|
||||||
|
copyBufferFrom(other);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DataBuffer::DataBuffer(void* primary, void* special,
|
||||||
|
const size_t lenInBytes, const DataType dataType,
|
||||||
|
const bool isOwnerPrimary, const bool isOwnerSpecial,
|
||||||
|
memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
if (primary == nullptr && special == nullptr)
|
||||||
|
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
|
||||||
|
|
||||||
|
_primaryBuffer = primary;
|
||||||
|
_specialBuffer = special;
|
||||||
|
_lenInBytes = lenInBytes;
|
||||||
|
_dataType = dataType;
|
||||||
|
_workspace = workspace;
|
||||||
|
_isOwnerPrimary = isOwnerPrimary;
|
||||||
|
_isOwnerSpecial = isOwnerSpecial;
|
||||||
|
_deviceId = nd4j::AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
setCountersToZero();
|
||||||
|
|
||||||
|
if(primary != nullptr)
|
||||||
|
readPrimary();
|
||||||
|
if(special != nullptr)
|
||||||
|
readSpecial();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
|
||||||
|
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
|
||||||
|
|
||||||
|
syncToSpecial(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// copies data from hostBuffer to own memory buffer
|
||||||
|
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
if (hostBuffer == nullptr)
|
||||||
|
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
|
||||||
|
if (lenInBytes == 0)
|
||||||
|
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
|
||||||
|
|
||||||
|
_primaryBuffer = nullptr;
|
||||||
|
_specialBuffer = nullptr;
|
||||||
|
_lenInBytes = lenInBytes;
|
||||||
|
_dataType = dataType;
|
||||||
|
_workspace = workspace;
|
||||||
|
|
||||||
|
_deviceId = nd4j::AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
setCountersToZero();
|
||||||
|
|
||||||
|
allocateBuffers();
|
||||||
|
|
||||||
|
copyBufferFromHost(hostBuffer, lenInBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
|
||||||
|
|
||||||
|
_dataType = dataType;
|
||||||
|
_workspace = workspace;
|
||||||
|
_lenInBytes = lenInBytes;
|
||||||
|
|
||||||
|
_primaryBuffer = nullptr;
|
||||||
|
_specialBuffer = nullptr;
|
||||||
|
|
||||||
|
_deviceId = nd4j::AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
setCountersToZero();
|
||||||
|
|
||||||
|
if(lenInBytes != 0) {
|
||||||
|
allocateBuffers(allocBoth);
|
||||||
|
writeSpecial();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// move constructor
|
||||||
|
DataBuffer::DataBuffer(DataBuffer&& other) {
|
||||||
|
|
||||||
|
_primaryBuffer = other._primaryBuffer;
|
||||||
|
_specialBuffer = other._specialBuffer;
|
||||||
|
_lenInBytes = other._lenInBytes;
|
||||||
|
_dataType = other._dataType;
|
||||||
|
_workspace = other._workspace;
|
||||||
|
_isOwnerPrimary = other._isOwnerPrimary;
|
||||||
|
_isOwnerSpecial = other._isOwnerSpecial;
|
||||||
|
_deviceId.store(other._deviceId);
|
||||||
|
|
||||||
|
copyCounters(other);
|
||||||
|
|
||||||
|
other._primaryBuffer = other._specialBuffer = nullptr;
|
||||||
|
other.setAllocFlags(false, false);
|
||||||
|
other._lenInBytes = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// assignment operator
|
||||||
|
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
|
||||||
|
|
||||||
|
if (this == &other)
|
||||||
|
return *this;
|
||||||
|
|
||||||
|
deleteBuffers();
|
||||||
|
|
||||||
|
_lenInBytes = other._lenInBytes;
|
||||||
|
_dataType = other._dataType;
|
||||||
|
_workspace = other._workspace;
|
||||||
|
|
||||||
|
allocateBuffers();
|
||||||
|
copyBufferFrom(other);
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
// move assignment operator
|
||||||
|
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
|
||||||
|
|
||||||
|
if (this == &other)
|
||||||
|
return *this;
|
||||||
|
|
||||||
|
deleteBuffers();
|
||||||
|
|
||||||
|
_primaryBuffer = other._primaryBuffer;
|
||||||
|
_specialBuffer = other._specialBuffer;
|
||||||
|
_lenInBytes = other._lenInBytes;
|
||||||
|
_dataType = other._dataType;
|
||||||
|
_workspace = other._workspace;
|
||||||
|
_isOwnerPrimary = other._isOwnerPrimary;
|
||||||
|
_isOwnerSpecial = other._isOwnerSpecial;
|
||||||
|
|
||||||
|
copyCounters(other);
|
||||||
|
|
||||||
|
other._primaryBuffer = other._specialBuffer = nullptr;
|
||||||
|
other.setAllocFlags(false, false);
|
||||||
|
other._lenInBytes = 0;
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* DataBuffer::primary() {
|
||||||
|
return _primaryBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* DataBuffer::special() {
|
||||||
|
return _specialBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DataType DataBuffer::getDataType() {
|
||||||
|
return _dataType;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
size_t DataBuffer::getLenInBytes() const {
|
||||||
|
return _lenInBytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::allocatePrimary() {
|
||||||
|
|
||||||
|
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
|
||||||
|
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
|
||||||
|
_isOwnerPrimary = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
|
||||||
|
|
||||||
|
_isOwnerPrimary = isOwnerPrimary;
|
||||||
|
_isOwnerSpecial = isOwnerSpecial;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::deletePrimary() {
|
||||||
|
|
||||||
|
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
|
||||||
|
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
|
||||||
|
RELEASE(p, _workspace);
|
||||||
|
_primaryBuffer = nullptr;
|
||||||
|
_isOwnerPrimary = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::deleteBuffers() {
|
||||||
|
|
||||||
|
deletePrimary();
|
||||||
|
deleteSpecial();
|
||||||
|
_lenInBytes = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
DataBuffer::~DataBuffer() {
|
||||||
|
|
||||||
|
deleteBuffers();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) {
|
||||||
|
if (_primaryBuffer != nullptr && _isOwnerPrimary) {
|
||||||
|
deletePrimary();
|
||||||
|
}
|
||||||
|
_primaryBuffer = buffer;
|
||||||
|
_isOwnerPrimary = false;
|
||||||
|
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
|
||||||
|
this->setSpecial(buffer, false);
|
||||||
|
_lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DataBuffer::setDataType(DataType dataType) {
|
||||||
|
_dataType = dataType;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DataBuffer::deviceId() const {
|
||||||
|
return _deviceId.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DataBuffer::close() {
|
||||||
|
this->deleteBuffers();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DataBuffer::setDeviceId(int deviceId) {
|
||||||
|
_deviceId = deviceId;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,146 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <array/InteropDataBuffer.h>
|
||||||
|
#include <array/DataTypeUtils.h>
|
||||||
|
#include <execution/AffinityManager.h>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) {
|
||||||
|
_dataBuffer = dataBuffer.getDataBuffer();
|
||||||
|
|
||||||
|
// offset is always absolute to the original buffer
|
||||||
|
_offset = offset;
|
||||||
|
|
||||||
|
if (_offset + length > _dataBuffer->getLenInBytes()) {
|
||||||
|
throw std::runtime_error("offset + length is higher than original length");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InteropDataBuffer::InteropDataBuffer(std::shared_ptr<DataBuffer> databuffer) {
|
||||||
|
_dataBuffer = databuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
InteropDataBuffer::InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth) {
|
||||||
|
if (elements == 0) {
|
||||||
|
_dataBuffer = std::make_shared<DataBuffer>();
|
||||||
|
_dataBuffer->setDataType(dtype);
|
||||||
|
} else {
|
||||||
|
_dataBuffer = std::make_shared<DataBuffer>(elements, dtype, nullptr, allocateBoth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<DataBuffer> InteropDataBuffer::getDataBuffer() const {
|
||||||
|
return _dataBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<DataBuffer> InteropDataBuffer::dataBuffer() {
|
||||||
|
return _dataBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* InteropDataBuffer::primary() const {
|
||||||
|
return reinterpret_cast<int8_t *>(_dataBuffer->primary()) + _offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* InteropDataBuffer::special() const {
|
||||||
|
return reinterpret_cast<int8_t *>(_dataBuffer->special()) + _offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::setPrimary(void* ptr, size_t length) {
|
||||||
|
_dataBuffer->setPrimaryBuffer(ptr, length);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::setSpecial(void* ptr, size_t length) {
|
||||||
|
_dataBuffer->setSpecialBuffer(ptr, length);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t InteropDataBuffer::offset() const {
|
||||||
|
return _offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::setOffset(uint64_t offset) {
|
||||||
|
_offset = offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
int InteropDataBuffer::deviceId() const {
|
||||||
|
return _dataBuffer->deviceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void InteropDataBuffer::registerSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
|
||||||
|
for (const auto &v:writeList) {
|
||||||
|
if (v == nullptr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
v->getDataBuffer()->writeSpecial();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::prepareSpecialUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
|
||||||
|
auto currentDeviceId = nd4j::AffinityManager::currentDeviceId();
|
||||||
|
for (const auto &v:readList) {
|
||||||
|
if (v == nullptr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (v->getDataBuffer()->deviceId() != currentDeviceId)
|
||||||
|
v->getDataBuffer()->migrate();
|
||||||
|
|
||||||
|
v->getDataBuffer()->syncToSpecial();
|
||||||
|
}
|
||||||
|
|
||||||
|
// we don't tick write list, only ensure the same device affinity
|
||||||
|
for (const auto &v:writeList) {
|
||||||
|
if (v == nullptr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// special case for legacy ops - views can be updated on host side, thus original array can be not updated
|
||||||
|
if (!v->getDataBuffer()->isSpecialActual())
|
||||||
|
v->getDataBuffer()->syncToSpecial();
|
||||||
|
|
||||||
|
if (v->getDataBuffer()->deviceId() != currentDeviceId)
|
||||||
|
v->getDataBuffer()->migrate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::registerPrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList) {
|
||||||
|
for (const auto &v:writeList) {
|
||||||
|
if (v == nullptr)
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::preparePrimaryUse(const std::vector<const InteropDataBuffer*>& writeList, const std::vector<const InteropDataBuffer*>& readList, bool synchronizeWritables) {
|
||||||
|
for (const auto &v:readList) {
|
||||||
|
if (v == nullptr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::expand(size_t newlength) {
|
||||||
|
_dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void InteropDataBuffer::setDeviceId(int deviceId) {
|
||||||
|
_dataBuffer->setDeviceId(deviceId);
|
||||||
|
}
|
||||||
|
}
|
|
@ -138,7 +138,7 @@ namespace nd4j {
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("_reductionPointer allocation failed", res);
|
throw cuda_exception::build("_reductionPointer allocation failed", res);
|
||||||
|
|
||||||
res = cudaMalloc(reinterpret_cast<void**>(&_scalarPointer), 16);
|
res = cudaHostAlloc(reinterpret_cast<void**>(&_scalarPointer), 16, cudaHostAllocDefault);
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw cuda_exception::build("_scalarPointer allocation failed", res);
|
throw cuda_exception::build("_scalarPointer allocation failed", res);
|
||||||
|
|
||||||
|
|
|
@ -185,9 +185,11 @@ namespace nd4j {
|
||||||
|
|
||||||
void setInputArray(int index, NDArray *array, bool removable = false);
|
void setInputArray(int index, NDArray *array, bool removable = false);
|
||||||
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
|
void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
|
|
||||||
void setOutputArray(int index, NDArray *array, bool removable = false);
|
void setOutputArray(int index, NDArray *array, bool removable = false);
|
||||||
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
|
void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo);
|
||||||
|
|
||||||
void setTArguments(double *arguments, int numberOfArguments);
|
void setTArguments(double *arguments, int numberOfArguments);
|
||||||
void setIArguments(Nd4jLong *arguments, int numberOfArguments);
|
void setIArguments(Nd4jLong *arguments, int numberOfArguments);
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <Context.h>
|
#include <Context.h>
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
#include <graph/Context.h>
|
#include <graph/Context.h>
|
||||||
|
#include <array/InteropDataBuffer.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -426,6 +427,44 @@ namespace nd4j {
|
||||||
array->setContext(_context);
|
array->setContext(_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
|
||||||
|
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
|
||||||
|
|
||||||
|
if (_fastpath_in.size() < index + 1)
|
||||||
|
_fastpath_in.resize(index+1);
|
||||||
|
|
||||||
|
NDArray *array;
|
||||||
|
if (dataBuffer != nullptr)
|
||||||
|
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
|
||||||
|
else
|
||||||
|
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
|
||||||
|
|
||||||
|
_fastpath_in[index] = array;
|
||||||
|
_handles.emplace_back(array);
|
||||||
|
|
||||||
|
if (_context != nullptr)
|
||||||
|
array->setContext(_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) {
|
||||||
|
auto dataBuffer = reinterpret_cast<InteropDataBuffer*>(vdatabuffer);
|
||||||
|
|
||||||
|
if (_fastpath_out.size() < index + 1)
|
||||||
|
_fastpath_out.resize(index+1);
|
||||||
|
|
||||||
|
NDArray *array;
|
||||||
|
if (dataBuffer != nullptr)
|
||||||
|
array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast<Nd4jLong *>(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast<Nd4jLong *>(shapeInfo))));
|
||||||
|
else
|
||||||
|
array = new NDArray(nullptr, nullptr, reinterpret_cast<Nd4jLong *>(shapeInfo));
|
||||||
|
|
||||||
|
_fastpath_out[index] = array;
|
||||||
|
_handles.emplace_back(array);
|
||||||
|
|
||||||
|
if (_context != nullptr)
|
||||||
|
array->setContext(_context);
|
||||||
|
}
|
||||||
|
|
||||||
void Context::setTArguments(double *arguments, int numberOfArguments) {
|
void Context::setTArguments(double *arguments, int numberOfArguments) {
|
||||||
_tArgs.clear();
|
_tArgs.clear();
|
||||||
_tArgs.reserve(numberOfArguments);
|
_tArgs.reserve(numberOfArguments);
|
||||||
|
|
|
@ -43,6 +43,8 @@ enum DType:byte {
|
||||||
QINT16,
|
QINT16,
|
||||||
BFLOAT16 = 17,
|
BFLOAT16 = 17,
|
||||||
UTF8 = 50,
|
UTF8 = 50,
|
||||||
|
UTF16 = 51,
|
||||||
|
UTF32 = 52,
|
||||||
}
|
}
|
||||||
|
|
||||||
// this structure describe NDArray
|
// this structure describe NDArray
|
||||||
|
|
|
@ -34,8 +34,6 @@
|
||||||
#include <driver_types.h>
|
#include <driver_types.h>
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#include <DebugInfo.h>
|
#include <DebugInfo.h>
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
|
@ -25,6 +25,8 @@
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
#include <NDArray.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class ND4J_EXPORT StringUtils {
|
class ND4J_EXPORT StringUtils {
|
||||||
|
@ -53,6 +55,36 @@ namespace nd4j {
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of needle matches within haystack
|
||||||
|
* PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8
|
||||||
|
*
|
||||||
|
* @param haystack
|
||||||
|
* @param haystackLength
|
||||||
|
* @param needle
|
||||||
|
* @param needleLength
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of bytes used for string NDArrays content
|
||||||
|
* PLEASE NOTE: this doesn't include header
|
||||||
|
*
|
||||||
|
* @param array
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static uint64_t byteLength(const NDArray &array);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method splits a string into substring by delimiter
|
||||||
|
*
|
||||||
|
* @param haystack
|
||||||
|
* @param delimiter
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,58 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <helpers/StringUtils.h>
|
#include <helpers/StringUtils.h>
|
||||||
|
#include <exceptions/datatype_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) {
|
||||||
|
for (int e = 0; e < length; e++)
|
||||||
|
if (haystack[e] != needle[e])
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) {
|
||||||
|
auto haystack = reinterpret_cast<const uint8_t*>(vhaystack);
|
||||||
|
auto needle = reinterpret_cast<const uint8_t*>(vneedle);
|
||||||
|
|
||||||
|
uint64_t number = 0;
|
||||||
|
|
||||||
|
for (uint64_t e = 0; e < haystackLength - needleLength; e++) {
|
||||||
|
if (match(&haystack[e], needle, needleLength))
|
||||||
|
number++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return number;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
uint64_t StringUtils::byteLength(const NDArray &array) {
|
||||||
|
if (!array.isS())
|
||||||
|
throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());
|
||||||
|
|
||||||
|
uint64_t result = 0;
|
||||||
|
|
||||||
|
// our buffer stores offsets, and the last value is basically number of bytes used
|
||||||
|
auto buffer = array.bufferAsT<Nd4jLong>();
|
||||||
|
result = buffer[array.lengthOf()];
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
|
||||||
|
std::vector<std::string> output;
|
||||||
|
|
||||||
|
std::string::size_type prev_pos = 0, pos = 0;
|
||||||
|
|
||||||
|
// iterating through the haystack till the end
|
||||||
|
while((pos = haystack.find(delimiter, pos)) != std::string::npos) {
|
||||||
|
output.emplace_back(haystack.substr(prev_pos, pos-prev_pos));
|
||||||
|
prev_pos = ++pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <ShapeUtils.h>
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <loops/reduce_bool.h>
|
#include <loops/reduce_bool.h>
|
||||||
#include <loops/legacy_ops.h>
|
#include <loops/legacy_ops.h>
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <ShapeUtils.h>
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <loops/reduce_float.h>
|
#include <loops/reduce_float.h>
|
||||||
#include <loops/legacy_ops.h>
|
#include <loops/legacy_ops.h>
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <ShapeUtils.h>
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <loops/reduce_long.h>
|
#include <loops/reduce_long.h>
|
||||||
#include <loops/legacy_ops.h>
|
#include <loops/legacy_ops.h>
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <ShapeUtils.h>
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <loops/reduce_same.h>
|
#include <loops/reduce_same.h>
|
||||||
#include <loops/legacy_ops.h>
|
#include <loops/legacy_ops.h>
|
||||||
|
|
|
@ -1624,4 +1624,9 @@
|
||||||
|
|
||||||
#define PARAMETRIC_D() [&] (Parameters &p) -> Context*
|
#define PARAMETRIC_D() [&] (Parameters &p) -> Context*
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");}
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -40,6 +40,9 @@
|
||||||
#include <ops/declarable/headers/third_party.h>
|
#include <ops/declarable/headers/third_party.h>
|
||||||
#include <ops/declarable/headers/tests.h>
|
#include <ops/declarable/headers/tests.h>
|
||||||
#include <ops/declarable/headers/kernels.h>
|
#include <ops/declarable/headers/kernels.h>
|
||||||
|
#include <ops/declarable/headers/strings.h>
|
||||||
|
#include <ops/declarable/headers/compat.h>
|
||||||
|
#include <ops/declarable/headers/util.h>
|
||||||
#include <ops/declarable/headers/BarnesHutTsne.h>
|
#include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
#include <ops/declarable/headers/images.h>
|
#include <ops/declarable/headers/images.h>
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
This folder contains operations required for compatibility with TF and other frameworks.
|
|
@ -0,0 +1,73 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/sparse_to_dense.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) {
|
||||||
|
auto indices = INPUT_VARIABLE(0);
|
||||||
|
auto shape = INPUT_VARIABLE(1);
|
||||||
|
auto values = INPUT_VARIABLE(2);
|
||||||
|
NDArray *def = nullptr;
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (block.width() > 3)
|
||||||
|
def = INPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
nd4j::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(compat_sparse_to_dense) {
|
||||||
|
auto indices = INPUT_VARIABLE(0);
|
||||||
|
auto shape = INPUT_VARIABLE(1);
|
||||||
|
auto values = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
auto def = INPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values")
|
||||||
|
};
|
||||||
|
|
||||||
|
auto dtype = values->dataType();
|
||||||
|
|
||||||
|
// basically output shape is defined by the type of input, and desired shape input
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector<Nd4jLong>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(compat_sparse_to_dense) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_INTS}) // indices
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS}) // shape
|
||||||
|
->setAllowedInputTypes(2,nd4j::DataType::ANY) // sparse values
|
||||||
|
->setAllowedInputTypes(3,nd4j::DataType::ANY) // default value
|
||||||
|
->setAllowedOutputTypes(nd4j::DataType::ANY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,140 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/StringUtils.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto delim = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto indices = OUTPUT_VARIABLE(0);
|
||||||
|
auto values = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto d = delim->e<std::string>(0);
|
||||||
|
|
||||||
|
input->syncToHost();
|
||||||
|
delim->syncToHost();
|
||||||
|
|
||||||
|
// output rank N+1 wrt input rank
|
||||||
|
std::vector<Nd4jLong> ocoords(input->rankOf() + 1);
|
||||||
|
std::vector<Nd4jLong> icoords(input->rankOf());
|
||||||
|
|
||||||
|
// getting buffer lengths
|
||||||
|
// FIXME: it'll be bigger, since it'll include delimiters,
|
||||||
|
auto outputLength = StringUtils::byteLength(*input);
|
||||||
|
|
||||||
|
uint64_t ss = 0L;
|
||||||
|
Nd4jLong ic = 0L;
|
||||||
|
// loop through each string within tensor
|
||||||
|
for (auto e = 0L; e < input->lengthOf(); e++) {
|
||||||
|
// now we should map substring to indices
|
||||||
|
auto s = input->e<std::string>(e);
|
||||||
|
|
||||||
|
// getting base index
|
||||||
|
shape::index2coords(e, input->shapeInfo(), icoords.data());
|
||||||
|
|
||||||
|
// getting number of substrings
|
||||||
|
auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
|
||||||
|
|
||||||
|
// filling output indices
|
||||||
|
for (uint64_t f = 0; f < cnt; f++) {
|
||||||
|
for (auto v: icoords)
|
||||||
|
indices->p(ic++, v);
|
||||||
|
|
||||||
|
// last index
|
||||||
|
indices->p(ic++, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
ss += cnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// process strings now
|
||||||
|
std::vector<std::string> strings;
|
||||||
|
for (auto e = 0L; e < input->lengthOf(); e++) {
|
||||||
|
auto split = StringUtils::split(input->e<std::string>(e), d);
|
||||||
|
|
||||||
|
for (const auto &s:split)
|
||||||
|
strings.emplace_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// now once we have all strings in single vector time to fill
|
||||||
|
auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
|
||||||
|
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
||||||
|
|
||||||
|
// for CUDA mostly
|
||||||
|
values->dataBuffer()->allocatePrimary();
|
||||||
|
values->dataBuffer()->expand(blen);
|
||||||
|
memcpy(values->buffer(), tmp.buffer(), blen);
|
||||||
|
values->tickWriteHost();
|
||||||
|
|
||||||
|
// special case, for future use
|
||||||
|
indices->syncToDevice();
|
||||||
|
values->syncToDevice();
|
||||||
|
|
||||||
|
// we have to tick buffers
|
||||||
|
values->dataBuffer()->writePrimary();
|
||||||
|
values->dataBuffer()->readSpecial();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(compat_string_split) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto delim = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto d = delim->e<std::string>(0);
|
||||||
|
|
||||||
|
// count number of delimiter substrings in all strings within input tensor
|
||||||
|
uint64_t cnt = 0;
|
||||||
|
for (auto e = 0L; e < input->lengthOf(); e++) {
|
||||||
|
// FIXME: bad, not UTF-compatible
|
||||||
|
auto s = input->e<std::string>(e);
|
||||||
|
|
||||||
|
// each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays
|
||||||
|
cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// shape calculations
|
||||||
|
// virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension
|
||||||
|
// values tensor is going to be vector always
|
||||||
|
// indices tensor is going to be vector with length equal to values.length * output rank
|
||||||
|
|
||||||
|
auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8);
|
||||||
|
auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64);
|
||||||
|
|
||||||
|
return SHAPELIST(indicesShape, valuesShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(compat_string_split) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_STRINGS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_INDICES})
|
||||||
|
->setAllowedOutputTypes(1, {ALL_STRINGS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -47,8 +47,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
// just memcpy data
|
// just memcpy data
|
||||||
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant
|
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer());
|
||||||
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto delim = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(split_string) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto delim = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
return SHAPELIST();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(split_string) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_STRINGS})
|
||||||
|
->setAllowedOutputTypes({ALL_STRINGS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,52 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_print_affinity)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/print_variable.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) {
|
||||||
|
// TODO: make this op compatible with ArrayList etc
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
nd4j_printf("<Node %i>: Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->getBuffer(), input->getSpecialBuffer(), input->dataBuffer()->getLenInBytes());
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(print_affinity) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, {ALL_STRINGS})
|
||||||
|
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(print_affinity) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,77 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_print_variable)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/print_variable.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) {
|
||||||
|
// TODO: make this op compatible with ArrayList etc
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
std::string str;
|
||||||
|
|
||||||
|
if (block.width() == 2) {
|
||||||
|
auto message = INPUT_VARIABLE(1);
|
||||||
|
REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String");
|
||||||
|
|
||||||
|
str = message->e<std::string>(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool printSpecial = false;
|
||||||
|
if (block.numB() > 0)
|
||||||
|
printSpecial = B_ARG(0);
|
||||||
|
|
||||||
|
if (printSpecial && !nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
// only specific backends support special printout. for cpu-based backends it's the same as regular print
|
||||||
|
|
||||||
|
if (block.width() == 2)
|
||||||
|
helpers::print_special(*block.launchContext(), *input, str);
|
||||||
|
else
|
||||||
|
helpers::print_special(*block.launchContext(), *input);
|
||||||
|
} else {
|
||||||
|
// optionally add message to the print out
|
||||||
|
if (block.width() == 2) {
|
||||||
|
input->printIndexedBuffer(str.c_str());
|
||||||
|
} else {
|
||||||
|
input->printIndexedBuffer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(print_variable) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
|
->setAllowedInputTypes(1, {ALL_STRINGS})
|
||||||
|
->setAllowedOutputTypes(0, nd4j::DataType::INT32);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(print_variable) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_COMPAT_H
|
||||||
|
#define SAMEDIFF_COMPAT_H
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/common.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
/**
|
||||||
|
* This operation splits input string into pieces separated by delimiter
|
||||||
|
* PLEASE NOTE: This implementation is compatible with TF 1.x
|
||||||
|
*
|
||||||
|
* Input[0] - string to split
|
||||||
|
* Input[1] - delimiter
|
||||||
|
*
|
||||||
|
* Returns:
|
||||||
|
* Output[0] - indices tensor
|
||||||
|
* Output[1] - values tensor
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_compat_string_split)
|
||||||
|
DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation converts TF sparse array representation to dense NDArray
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_compat_sparse_to_dense)
|
||||||
|
DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SAMEDIFF_COMPAT_H
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_STRINGS_H
|
||||||
|
#define SAMEDIFF_STRINGS_H
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/common.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
/**
|
||||||
|
* This operation splits input string into pieces separated by delimiter
|
||||||
|
*
|
||||||
|
* Input[0] - string to split
|
||||||
|
* Input[1] - delimiter
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_split_string)
|
||||||
|
DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SAMEDIFF_STRINGS_H
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_UTILS_H
|
||||||
|
#define LIBND4J_UTILS_H
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/common.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
/**
|
||||||
|
* This operation prints out NDArray content, either on host or device.
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_print_variable)
|
||||||
|
DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation prints out affinity & locality status of given NDArray
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_print_affinity)
|
||||||
|
DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //LIBND4J_UTILS_H
|
|
@ -0,0 +1,31 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/print_variable.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
|
||||||
|
array.printIndexedBuffer(message.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,15 +40,11 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
const auto y = reinterpret_cast<const Y*>(vy);
|
const auto y = reinterpret_cast<const Y*>(vy);
|
||||||
auto z = reinterpret_cast<X*>(vz);
|
auto z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
__shared__ Nd4jLong xzLen, totalThreads, *sharedMem;
|
__shared__ Nd4jLong xzLen;
|
||||||
__shared__ int xzRank, yRank;
|
__shared__ int xzRank, yRank;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
|
|
||||||
xzLen = shape::length(xShapeInfo);
|
xzLen = shape::length(xShapeInfo);
|
||||||
totalThreads = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
xzRank = shape::rank(xShapeInfo);
|
xzRank = shape::rank(xShapeInfo);
|
||||||
yRank = shape::rank(yShapeInfo);
|
yRank = shape::rank(yShapeInfo);
|
||||||
|
@ -56,18 +52,15 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
Nd4jLong* coords = sharedMem + threadIdx.x * xzRank;
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
|
||||||
for (int i = tid; i < xzLen; i += totalThreads) {
|
|
||||||
|
|
||||||
|
for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) {
|
||||||
shape::index2coords(i, xShapeInfo, coords);
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
|
||||||
const auto xzOffset = shape::getOffset(xShapeInfo, coords);
|
const auto xzOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
|
||||||
const auto xVal = x[xzOffset];
|
const auto xVal = x[xzOffset];
|
||||||
|
|
||||||
if(xVal < 0) {
|
if(xVal < 0) {
|
||||||
|
|
||||||
for (uint j = 0; j < yRank; ++j)
|
for (uint j = 0; j < yRank; ++j)
|
||||||
if(yShapeInfo[j + 1] == 1)
|
if(yShapeInfo[j + 1] == 1)
|
||||||
coords[j + 1] = 0;
|
coords[j + 1] = 0;
|
||||||
|
@ -82,7 +75,6 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
|
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
|
||||||
|
|
||||||
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
|
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,9 +83,9 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a
|
||||||
|
|
||||||
PointersManager manager(context, "prelu");
|
PointersManager manager(context, "prelu");
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
const int threadsPerBlock = 256;
|
||||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = 512;
|
||||||
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = 512;
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
const auto xType = input.dataType();
|
||||||
const auto yType = alpha.dataType();
|
const auto yType = alpha.dataType();
|
||||||
|
@ -119,13 +111,10 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
|
||||||
auto dLdI = reinterpret_cast<Y*>(vdLdI);
|
auto dLdI = reinterpret_cast<Y*>(vdLdI);
|
||||||
auto dLdA = reinterpret_cast<Y*>(vdLdA);
|
auto dLdA = reinterpret_cast<Y*>(vdLdA);
|
||||||
|
|
||||||
__shared__ Nd4jLong inLen, totalThreads, *sharedMem;
|
__shared__ Nd4jLong inLen, totalThreads;
|
||||||
__shared__ int inRank, alphaRank;
|
__shared__ int inRank, alphaRank;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
|
|
||||||
inLen = shape::length(inShapeInfo);
|
inLen = shape::length(inShapeInfo);
|
||||||
totalThreads = gridDim.x * blockDim.x;
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
@ -135,10 +124,9 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
Nd4jLong* coords = sharedMem + threadIdx.x * inRank;
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
|
||||||
for (int i = tid; i < inLen; i += totalThreads) {
|
for (int i = tid; i < inLen; i += totalThreads) {
|
||||||
|
|
||||||
shape::index2coords(i, inShapeInfo, coords);
|
shape::index2coords(i, inShapeInfo, coords);
|
||||||
|
|
||||||
const auto inOffset = shape::getOffset(inShapeInfo, coords);
|
const auto inOffset = shape::getOffset(inShapeInfo, coords);
|
||||||
|
@ -175,14 +163,13 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
||||||
|
dLdA.nullify();
|
||||||
dLdA.nullify();
|
|
||||||
|
|
||||||
PointersManager manager(context, "preluBP");
|
PointersManager manager(context, "preluBP");
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
const int threadsPerBlock = 256;
|
||||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = 512;
|
||||||
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = 512;
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
const auto xType = input.dataType();
|
||||||
const auto zType = alpha.dataType();
|
const auto zType = alpha.dataType();
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/print_variable.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) {
|
||||||
|
auto length = shape::length(shapeInfo);
|
||||||
|
auto x = reinterpret_cast<const T*>(special);
|
||||||
|
|
||||||
|
// TODO: add formatting here
|
||||||
|
printf("[");
|
||||||
|
|
||||||
|
for (uint64_t e = 0; e < length; e++) {
|
||||||
|
printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]);
|
||||||
|
|
||||||
|
if (e < length - 1)
|
||||||
|
printf(", ");
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("]\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) {
|
||||||
|
print_device<T><<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) {
|
||||||
|
NDArray::prepareSpecialUse({}, {&array});
|
||||||
|
|
||||||
|
PointersManager pm(&ctx, "print_device");
|
||||||
|
BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.getSpecialBuffer(), array.getSpecialShapeInfo()), LIBND4J_TYPES)
|
||||||
|
pm.synchronize();
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({}, {&array});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -41,6 +41,9 @@
|
||||||
#include <helpers/DebugHelper.h>
|
#include <helpers/DebugHelper.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include <DebugHelper.h>
|
||||||
|
|
||||||
#endif // CUDACC
|
#endif // CUDACC
|
||||||
|
|
||||||
#endif // LIBND4J_HELPERS_H
|
#endif // LIBND4J_HELPERS_H
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/sparse_to_dense.h>
|
||||||
|
#include <helpers/StringUtils.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename X, typename I>
|
||||||
|
static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) {
|
||||||
|
auto values = reinterpret_cast<const X*>(vvalues);
|
||||||
|
auto indices = reinterpret_cast<const I*>(vindices);
|
||||||
|
auto output = reinterpret_cast<X*>(voutput);
|
||||||
|
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
uint64_t pos = 0;
|
||||||
|
for (uint64_t e = 0L; e < length; e++) {
|
||||||
|
// indices come in blocks
|
||||||
|
for (uint8_t p = 0; p < rank; p++) {
|
||||||
|
coords[p] = indices[pos++];
|
||||||
|
}
|
||||||
|
|
||||||
|
// fill output at given coords with sparse value
|
||||||
|
output[shape::getOffset(zShapeInfo, coords)] = values[e];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) {
|
||||||
|
// make sure host buffer is updated
|
||||||
|
values.syncToHost();
|
||||||
|
indices.syncToHost();
|
||||||
|
|
||||||
|
auto rank = output.rankOf();
|
||||||
|
|
||||||
|
if (output.isS()) {
|
||||||
|
// string case is not so trivial, since elements might, and probably will, have different sizes
|
||||||
|
auto numValues = values.lengthOf();
|
||||||
|
auto numElements = output.lengthOf();
|
||||||
|
|
||||||
|
// first of all we calculate final buffer sizes and offsets
|
||||||
|
auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def);
|
||||||
|
auto valuesLength = StringUtils::byteLength(values);
|
||||||
|
auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength;
|
||||||
|
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements);
|
||||||
|
|
||||||
|
// now we make sure our output buffer can hold results
|
||||||
|
output.dataBuffer()->expand( bufferLength + headerLength);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> outputCoords(rank);
|
||||||
|
std::vector<Nd4jLong> valueCoords(rank);
|
||||||
|
|
||||||
|
auto offsetsBuffer = output.bufferAsT<Nd4jLong>();
|
||||||
|
auto dataBuffer = reinterpret_cast<uint8_t*>(offsetsBuffer + output.lengthOf());
|
||||||
|
|
||||||
|
offsetsBuffer[0] = 0;
|
||||||
|
|
||||||
|
// getting initial value coords
|
||||||
|
for (int e = 0; e < rank; e++)
|
||||||
|
valueCoords[e] = indices.e<Nd4jLong>(e);
|
||||||
|
|
||||||
|
// write results individually
|
||||||
|
for (uint64_t e = 0; e < numElements; e++) {
|
||||||
|
auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data());
|
||||||
|
auto cLength = 0L;
|
||||||
|
std::string str;
|
||||||
|
if (vIndex == e) {
|
||||||
|
// we're writing down sparse value here
|
||||||
|
str = values.e<std::string>(e);
|
||||||
|
} else {
|
||||||
|
// we're writing down default value if it exists
|
||||||
|
if (def != nullptr)
|
||||||
|
str = def->e<std::string>(0);
|
||||||
|
else
|
||||||
|
str = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: make it unicode compliant
|
||||||
|
memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length());
|
||||||
|
|
||||||
|
// writing down offset
|
||||||
|
offsetsBuffer[e+1] = cLength;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// numeric case is trivial, since all elements have equal sizes
|
||||||
|
|
||||||
|
// write out default values, if they are present
|
||||||
|
if (def != nullptr) {
|
||||||
|
output.assign(def);
|
||||||
|
|
||||||
|
// make sure output is synced back
|
||||||
|
output.syncToHost();
|
||||||
|
}
|
||||||
|
|
||||||
|
// write out values
|
||||||
|
BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.getBuffer(), indices.getBuffer(), output.buffer(), output.getShapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
// copy back to device, if there's any
|
||||||
|
output.syncToDevice();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_PRINT_VARIABLE_H
|
||||||
|
#define LIBND4J_PRINT_VARIABLE_H
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //LIBND4J_PRINT_VARIABLE_H
|
|
@ -0,0 +1,34 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_SPARSE_TO_DENSE_H
|
||||||
|
#define SAMEDIFF_SPARSE_TO_DENSE_H
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //SAMEDIFF_SPARSE_TO_DENSE_H
|
|
@ -634,7 +634,7 @@
|
||||||
#define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
#define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
||||||
#define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
#define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME)
|
||||||
|
|
||||||
|
#define ALL_STRINGS nd4j::DataType::UTF8, nd4j::DataType::UTF16, nd4j::DataType::UTF32
|
||||||
#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64
|
#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64
|
||||||
#define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64
|
#define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64
|
||||||
#define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16
|
#define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16
|
||||||
|
|
|
@ -810,9 +810,10 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
nativeStart[1] = (x.getContext()->getCudaStream());
|
nativeStart[1] = (x.getContext()->getCudaStream());
|
||||||
#endif
|
#endif
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
|
||||||
|
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
|
||||||
4, pidx,
|
4, pidx,
|
||||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||||
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
||||||
|
@ -844,8 +845,10 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
nativeStart[1] = (x.getContext()->getCudaStream());
|
nativeStart[1] = (x.getContext()->getCudaStream());
|
||||||
#endif
|
#endif
|
||||||
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.specialShapeInfo(),
|
||||||
|
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
|
||||||
4, pidx,
|
4, pidx,
|
||||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||||
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <GradCheck.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace nd4j;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTests17 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTests17() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
|
||||||
|
auto values = NDArrayFactory::create<float>({1.f, 2.f, 3.f});
|
||||||
|
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
|
||||||
|
auto def = NDArrayFactory::create<float>(0.f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f});
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::compat_sparse_to_dense op;
|
||||||
|
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||||
|
auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||||
|
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
|
||||||
|
auto def = NDArrayFactory::string("d");
|
||||||
|
auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::compat_sparse_to_dense op;
|
||||||
|
auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||||
|
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
|
||||||
|
auto delimiter = NDArrayFactory::string(" ");
|
||||||
|
|
||||||
|
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
||||||
|
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
|
||||||
|
|
||||||
|
nd4j::ops::compat_string_split op;
|
||||||
|
auto result = op.execute({&x, &delimiter}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
ASSERT_EQ(2, result->size());
|
||||||
|
|
||||||
|
auto z0 = result->at(0);
|
||||||
|
auto z1 = result->at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z0));
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, *z0);
|
||||||
|
ASSERT_EQ(exp1, *z1);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <GradCheck.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace nd4j;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTests18 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTests18() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests18, test_bitcast_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>(0.23028551377579154);
|
||||||
|
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>(4597464930322771456L);
|
||||||
|
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <GradCheck.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace nd4j;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTests19 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTests19() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
|
@ -834,12 +834,17 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
|
||||||
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dims.dataBuffer());
|
||||||
|
|
||||||
execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
&dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(),
|
||||||
|
packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
|
||||||
|
|
||||||
|
@ -981,10 +986,14 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
|
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(arrayX.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(arrayY.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(arrayZ.dataBuffer());
|
||||||
|
|
||||||
execPairwiseTransform(nullptr, pairwise::Add,
|
execPairwiseTransform(nullptr, pairwise::Add,
|
||||||
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
|
&xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(),
|
||||||
arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(),
|
&yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(),
|
||||||
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
|
&zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
|
NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
|
||||||
|
@ -1220,10 +1229,10 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) {
|
||||||
auto z = NDArrayFactory::create<bfloat16>('c', {10});
|
auto z = NDArrayFactory::create<bfloat16>('c', {10});
|
||||||
RandomGenerator rng(119, 323841120L);
|
RandomGenerator rng(119, 323841120L);
|
||||||
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
|
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
|
||||||
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args);
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args);
|
||||||
|
|
||||||
//z.printIndexedBuffer("z");
|
//z.printIndexedBuffer("z");
|
||||||
|
|
||||||
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
|
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1267,6 +1276,64 @@ TEST_F(JavaInteropTests, test_size_dtype_1) {
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
|
||||||
|
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
|
||||||
|
auto d = NDArrayFactory::string(" ");
|
||||||
|
|
||||||
|
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
|
||||||
|
auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""});
|
||||||
|
|
||||||
|
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
||||||
|
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
|
||||||
|
|
||||||
|
InteropDataBuffer iz0(z0.dataBuffer());
|
||||||
|
InteropDataBuffer iz1(z1.dataBuffer());
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo());
|
||||||
|
|
||||||
|
nd4j::ops::compat_string_split op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, z0);
|
||||||
|
ASSERT_EQ(exp1, z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
|
||||||
|
if (!Environment::getInstance()->isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {4, 3, 3, 3});
|
||||||
|
auto z = NDArrayFactory::create<double>('c', {4, 3, 4, 4});
|
||||||
|
|
||||||
|
double buffer[2048];
|
||||||
|
|
||||||
|
InteropDataBuffer ix(0, DataType::DOUBLE, false);
|
||||||
|
InteropDataBuffer iy(0, DataType::DOUBLE, false);
|
||||||
|
InteropDataBuffer iz(0, DataType::DOUBLE, false);
|
||||||
|
|
||||||
|
// we're imitating workspace-managed array here
|
||||||
|
ix.setPrimary(buffer + 64, x.lengthOf());
|
||||||
|
iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf());
|
||||||
|
iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf());
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo());
|
||||||
|
|
||||||
|
ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::maxpool2d_bp op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||||
|
|
|
@ -470,12 +470,16 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
|
||||||
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
|
execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
@ -506,14 +510,17 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
|
||||||
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
@ -543,14 +550,17 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
|
||||||
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
// z.printIndexedBuffer("z");
|
// z.printIndexedBuffer("z");
|
||||||
|
@ -583,13 +593,16 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
|
||||||
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
@ -615,10 +628,15 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&z}, {&x, &y});
|
NDArray::prepareSpecialUse({&z}, {&x, &y});
|
||||||
|
|
||||||
execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
|
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
|
||||||
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
||||||
|
@ -730,13 +748,16 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) {
|
||||||
auto z = NDArrayFactory::create<float>('c', {0, 2});
|
auto z = NDArrayFactory::create<float>('c', {0, 2});
|
||||||
auto e = NDArrayFactory::create<float>('c', {0, 2});
|
auto e = NDArrayFactory::create<float>('c', {0, 2});
|
||||||
|
|
||||||
|
InteropDataBuffer xdb(x.dataBuffer());
|
||||||
|
InteropDataBuffer ddb(d.dataBuffer());
|
||||||
|
InteropDataBuffer zdb(z.dataBuffer());
|
||||||
|
|
||||||
|
|
||||||
::execReduceSame2(nullptr, reduce::SameOps::Sum,
|
::execReduceSame2(nullptr, reduce::SameOps::Sum,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
&xdb, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zdb, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo());
|
&ddb, d.shapeInfo(), d.specialShapeInfo());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -119,13 +119,15 @@ TEST_F(NativeOpsTests, ExecIndexReduce_1) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
::execIndexReduceScalar(nullptr,
|
::execIndexReduceScalar(nullptr,
|
||||||
indexreduce::IndexMax,
|
indexreduce::IndexMax,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(),
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
nullptr,
|
||||||
nullptr, nullptr);
|
&expBuf, exp.shapeInfo(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 4LL);
|
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 4LL);
|
||||||
#endif
|
#endif
|
||||||
|
@ -140,15 +142,18 @@ TEST_F(NativeOpsTests, ExecIndexReduce_2) {
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
NDArray dimension = NDArrayFactory::create<int>({});
|
NDArray dimension = NDArrayFactory::create<int>({});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimensionBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execIndexReduce(nullptr,
|
::execIndexReduce(nullptr,
|
||||||
indexreduce::IndexMax,
|
indexreduce::IndexMax,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(),
|
||||||
nullptr, nullptr,
|
nullptr,
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
&dimensionBuf, dimension.shapeInfo(),
|
||||||
nullptr, nullptr);
|
nullptr);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 24LL);
|
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 24LL);
|
||||||
#endif
|
#endif
|
||||||
|
@ -166,16 +171,21 @@ TEST_F(NativeOpsTests, ExecBroadcast_1) {
|
||||||
#else
|
#else
|
||||||
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execBroadcast(nullptr,
|
::execBroadcast(nullptr,
|
||||||
broadcast::Add,
|
broadcast::Add,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(),
|
||||||
nullptr, nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&yBuf, y.shapeInfo(),
|
||||||
nullptr, nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(),
|
||||||
nullptr, nullptr,
|
nullptr,
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
&dimBuf, dimension.shapeInfo(),
|
||||||
nullptr, nullptr);
|
nullptr);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.e<float>(0) == 3.);
|
ASSERT_TRUE(exp.e<float>(0) == 3.);
|
||||||
#endif
|
#endif
|
||||||
|
@ -194,17 +204,18 @@ printf("Unsupported for cuda now.\n");
|
||||||
int dimd = 0;
|
int dimd = 0;
|
||||||
auto dimension = NDArrayFactory::create<int>('c', {1}, {dimd});
|
auto dimension = NDArrayFactory::create<int>('c', {1}, {dimd});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execBroadcastBool(nullptr,
|
::execBroadcastBool(nullptr,
|
||||||
broadcast::EqualTo,
|
broadcast::EqualTo,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
&yBuf, y.shapeInfo(), nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr, nullptr,
|
||||||
nullptr, nullptr,
|
&dimBuf, dimension.shapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
nullptr);
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
nullptr, nullptr);
|
|
||||||
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
|
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -219,14 +230,15 @@ TEST_F(NativeOpsTests, ExecPairwise_1) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execPairwiseTransform(nullptr,
|
::execPairwiseTransform(nullptr,
|
||||||
pairwise::Add,
|
pairwise::Add,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
&yBuf, y.shapeInfo(), nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr);
|
nullptr);
|
||||||
ASSERT_TRUE(exp.e<float>(5) == 8.);
|
ASSERT_TRUE(exp.e<float>(5) == 8.);
|
||||||
#endif
|
#endif
|
||||||
|
@ -243,14 +255,15 @@ TEST_F(NativeOpsTests, ExecPairwise_2) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execPairwiseTransformBool(nullptr,
|
::execPairwiseTransformBool(nullptr,
|
||||||
pairwise::And,
|
pairwise::And,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
&yBuf, y.shapeInfo(), nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr);
|
nullptr);
|
||||||
ASSERT_TRUE(exp.e<bool>(5) && !exp.e<bool>(4));
|
ASSERT_TRUE(exp.e<bool>(5) && !exp.e<bool>(4));
|
||||||
#endif
|
#endif
|
||||||
|
@ -266,14 +279,14 @@ TEST_F(NativeOpsTests, ReduceTest_1) {
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dimension = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceFloat(nullptr,
|
::execReduceFloat(nullptr,
|
||||||
reduce::Mean,
|
reduce::Mean,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr);
|
||||||
nullptr, nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce Mean");
|
// exp.printIndexedBuffer("Reduce Mean");
|
||||||
ASSERT_TRUE(exp.e<float>(0) == 13.);
|
ASSERT_TRUE(exp.e<float>(0) == 13.);
|
||||||
|
@ -289,14 +302,14 @@ TEST_F(NativeOpsTests, ReduceTest_2) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceSame(nullptr,
|
::execReduceSame(nullptr,
|
||||||
reduce::Sum,
|
reduce::Sum,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr);
|
||||||
nullptr, nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce Sum");
|
// exp.printIndexedBuffer("Reduce Sum");
|
||||||
ASSERT_TRUE(exp.e<float>(0) == 325.);
|
ASSERT_TRUE(exp.e<float>(0) == 325.);
|
||||||
|
@ -312,14 +325,14 @@ TEST_F(NativeOpsTests, ReduceTest_3) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceBool(nullptr,
|
::execReduceBool(nullptr,
|
||||||
reduce::All,
|
reduce::All,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr);
|
||||||
nullptr, nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
ASSERT_TRUE(exp.e<bool>(0) == true);
|
ASSERT_TRUE(exp.e<bool>(0) == true);
|
||||||
|
@ -335,14 +348,14 @@ TEST_F(NativeOpsTests, ReduceTest_4) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceLong(nullptr,
|
::execReduceLong(nullptr,
|
||||||
reduce::CountNonZero,
|
reduce::CountNonZero,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr);
|
||||||
nullptr, nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce CountNonZero");
|
// exp.printIndexedBuffer("Reduce CountNonZero");
|
||||||
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
|
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
|
||||||
|
@ -359,15 +372,16 @@ TEST_F(NativeOpsTests, ReduceTest_5) {
|
||||||
printf("Unsupported for cuda now.\n");
|
printf("Unsupported for cuda now.\n");
|
||||||
#else
|
#else
|
||||||
auto dimension = NDArrayFactory::create<int>({0, 1});
|
auto dimension = NDArrayFactory::create<int>({0, 1});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execReduceLong2(nullptr,
|
::execReduceLong2(nullptr,
|
||||||
reduce::CountNonZero,
|
reduce::CountNonZero,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
nullptr, nullptr,
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce CountNonZero");
|
// exp.printIndexedBuffer("Reduce CountNonZero");
|
||||||
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
|
ASSERT_TRUE(exp.e<Nd4jLong>(0) == 25LL);
|
||||||
|
@ -389,15 +403,17 @@ TEST_F(NativeOpsTests, ReduceTest_6) {
|
||||||
x.p(10, 0); x.p(11, 0);
|
x.p(10, 0); x.p(11, 0);
|
||||||
x.p(15, 0); x.p(16, 0); x.p(17, 0);
|
x.p(15, 0); x.p(16, 0); x.p(17, 0);
|
||||||
x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0);
|
x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0);
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceLong2(nullptr,
|
::execReduceLong2(nullptr,
|
||||||
reduce::CountNonZero,
|
reduce::CountNonZero,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), nullptr,
|
||||||
nullptr, nullptr,
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce CountNonZero");
|
// exp.printIndexedBuffer("Reduce CountNonZero");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -421,15 +437,16 @@ TEST_F(NativeOpsTests, ReduceTest_7) {
|
||||||
x.linspace(1.0);
|
x.linspace(1.0);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
dimension.syncToHost();
|
dimension.syncToHost();
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceFloat2(extra,
|
::execReduceFloat2(extra,
|
||||||
reduce::Mean,
|
reduce::Mean,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce Mean");
|
// exp.printIndexedBuffer("Reduce Mean");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -453,16 +470,16 @@ TEST_F(NativeOpsTests, ReduceTest_8) {
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
|
|
||||||
dimension.syncToHost();
|
dimension.syncToHost();
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
|
||||||
::execReduceSame2(extra,
|
::execReduceSame2(extra,
|
||||||
reduce::Sum,
|
reduce::Sum,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
z.buffer(), z.shapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
z.specialBuffer(), z.specialShapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce Sum");
|
// exp.printIndexedBuffer("Reduce Sum");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -485,15 +502,17 @@ TEST_F(NativeOpsTests, ReduceTest_9) {
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
|
|
||||||
dimension.syncToHost();
|
dimension.syncToHost();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduceBool2(extra,
|
::execReduceBool2(extra,
|
||||||
reduce::All,
|
reduce::All,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo());
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -518,15 +537,16 @@ TEST_F(NativeOpsTests, Reduce3Test_1) {
|
||||||
y.assign(2.);
|
y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduce3(extra,
|
::execReduce3(extra,
|
||||||
reduce3::Dot,
|
reduce3::Dot,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo());
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo());
|
|
||||||
//z.printIndexedBuffer("Z");
|
//z.printIndexedBuffer("Z");
|
||||||
//exp.printIndexedBuffer("Reduce3 Dot");
|
//exp.printIndexedBuffer("Reduce3 Dot");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -551,15 +571,16 @@ TEST_F(NativeOpsTests, Reduce3Test_2) {
|
||||||
y.assign(2.);
|
y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execReduce3Scalar(extra,
|
::execReduce3Scalar(extra,
|
||||||
reduce3::Dot,
|
reduce3::Dot,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo());
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo());
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce3 Dot");
|
// exp.printIndexedBuffer("Reduce3 Dot");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -585,17 +606,18 @@ TEST_F(NativeOpsTests, Reduce3Test_3) {
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
dimension.syncToHost();
|
dimension.syncToHost();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execReduce3Tad(extra,
|
::execReduce3Tad(extra,
|
||||||
reduce3::Dot,
|
reduce3::Dot,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo(),
|
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
nullptr, nullptr, nullptr, nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
|
@ -630,17 +652,18 @@ TEST_F(NativeOpsTests, Reduce3Test_4) {
|
||||||
auto hTADShapeInfoY = tadPackY.primaryShapeInfo();
|
auto hTADShapeInfoY = tadPackY.primaryShapeInfo();
|
||||||
auto hTADOffsetsY = tadPackY.primaryOffsets();
|
auto hTADOffsetsY = tadPackY.primaryOffsets();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execReduce3All(extra,
|
::execReduce3All(extra,
|
||||||
reduce3::Dot,
|
reduce3::Dot,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo(),
|
|
||||||
hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY);
|
hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
|
@ -667,14 +690,16 @@ TEST_F(NativeOpsTests, ScalarTest_1) {
|
||||||
//y.assign(2.);
|
//y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
z.syncToDevice();
|
z.syncToDevice();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execScalar(extra,
|
::execScalar(extra,
|
||||||
scalar::Multiply,
|
scalar::Multiply,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
y.buffer(), y.shapeInfo(),
|
|
||||||
y.specialBuffer(), y.specialShapeInfo(), nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -700,14 +725,16 @@ TEST_F(NativeOpsTests, ScalarTest_2) {
|
||||||
//y.assign(2.);
|
//y.assign(2.);
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
z.syncToDevice();
|
z.syncToDevice();
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execScalarBool(extra,
|
::execScalarBool(extra,
|
||||||
scalar::GreaterThan,
|
scalar::GreaterThan,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr);
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
y.buffer(), y.shapeInfo(),
|
|
||||||
y.specialBuffer(), y.specialShapeInfo(), nullptr);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15) != z.e<bool>(15));
|
ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15) != z.e<bool>(15));
|
||||||
|
@ -726,13 +753,14 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
|
||||||
printf("Unsupported for CUDA platform yet.\n");
|
printf("Unsupported for CUDA platform yet.\n");
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execSummaryStatsScalar(extra,
|
::execSummaryStatsScalar(extra,
|
||||||
variance::SummaryStatsVariance,
|
variance::SummaryStatsVariance,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(), false);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Standard Variance");
|
// exp.printIndexedBuffer("Standard Variance");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -751,13 +779,13 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
|
||||||
printf("Unsupported for CUDA platform yet.\n");
|
printf("Unsupported for CUDA platform yet.\n");
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
::execSummaryStats(extra,
|
::execSummaryStats(extra,
|
||||||
variance::SummaryStatsVariance,
|
variance::SummaryStatsVariance,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false);
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(), false);
|
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Standard Variance");
|
// exp.printIndexedBuffer("Standard Variance");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -777,15 +805,16 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
auto dimensions = NDArrayFactory::create<int>({0, 1});
|
auto dimensions = NDArrayFactory::create<int>({0, 1});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimensions.dataBuffer());
|
||||||
|
|
||||||
::execSummaryStatsTad(extra,
|
::execSummaryStatsTad(extra,
|
||||||
variance::SummaryStatsVariance,
|
variance::SummaryStatsVariance,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
&dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(),
|
||||||
dimensions.buffer(), dimensions.shapeInfo(),
|
|
||||||
dimensions.specialBuffer(), dimensions.specialShapeInfo(),
|
|
||||||
false,
|
false,
|
||||||
nullptr, nullptr);
|
nullptr, nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
|
@ -807,13 +836,15 @@ TEST_F(NativeOpsTests, TransformTest_1) {
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
z.linspace(1.);
|
z.linspace(1.);
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execTransformFloat(extra,
|
::execTransformFloat(extra,
|
||||||
transform::Sqrt,
|
transform::Sqrt,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
nullptr);
|
nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Sqrt is");
|
// exp.printIndexedBuffer("Sqrt is");
|
||||||
|
@ -834,13 +865,15 @@ TEST_F(NativeOpsTests, TransformTest_2) {
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
z.linspace(1.);
|
z.linspace(1.);
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execTransformSame(extra,
|
::execTransformSame(extra,
|
||||||
transform::Square,
|
transform::Square,
|
||||||
z.buffer(), z.shapeInfo(),
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
z.specialBuffer(), z.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
nullptr);
|
nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Square is");
|
// exp.printIndexedBuffer("Square is");
|
||||||
|
@ -864,13 +897,14 @@ TEST_F(NativeOpsTests, TransformTest_3) {
|
||||||
z.assign(true);
|
z.assign(true);
|
||||||
x.p(24, -25);
|
x.p(24, -25);
|
||||||
z.p(24, false);
|
z.p(24, false);
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execTransformBool(extra,
|
::execTransformBool(extra,
|
||||||
transform::IsPositive,
|
transform::IsPositive,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
nullptr);
|
nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("IsPositive");
|
// exp.printIndexedBuffer("IsPositive");
|
||||||
|
@ -894,13 +928,13 @@ TEST_F(NativeOpsTests, TransformTest_4) {
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
//z.linspace(1.);
|
//z.linspace(1.);
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::execTransformStrict(extra,
|
::execTransformStrict(extra,
|
||||||
transform::Cosine,
|
transform::Cosine,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
|
|
||||||
exp.buffer(), exp.shapeInfo(),
|
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
nullptr);
|
nullptr);
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Cosine");
|
// exp.printIndexedBuffer("Cosine");
|
||||||
|
@ -932,17 +966,18 @@ TEST_F(NativeOpsTests, ScalarTadTest_1) {
|
||||||
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execScalarTad(extra,
|
::execScalarTad(extra,
|
||||||
scalar::Multiply,
|
scalar::Multiply,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(), exp.specialShapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
|
||||||
y.buffer(), y.shapeInfo(),
|
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
&dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(),
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo(),
|
|
||||||
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
|
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("Reduce All");
|
// exp.printIndexedBuffer("Reduce All");
|
||||||
|
@ -977,17 +1012,21 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) {
|
||||||
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
|
||||||
z.assign(true);
|
z.assign(true);
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dimension.dataBuffer());
|
||||||
|
|
||||||
::execScalarBoolTad(extra,
|
::execScalarBoolTad(extra,
|
||||||
scalar::And,
|
scalar::And,
|
||||||
x.buffer(), x.shapeInfo(),
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
x.specialBuffer(), x.specialShapeInfo(),
|
&expBuf, exp.shapeInfo(),
|
||||||
exp.buffer(), exp.shapeInfo(),
|
exp.specialShapeInfo(),
|
||||||
exp.specialBuffer(), exp.specialShapeInfo(),
|
&yBuf, y.shapeInfo(),
|
||||||
y.buffer(), y.shapeInfo(),
|
y.specialShapeInfo(),
|
||||||
y.specialBuffer(), y.specialShapeInfo(),
|
|
||||||
nullptr,
|
nullptr,
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
&dimBuf, dimension.shapeInfo(),
|
||||||
dimension.specialBuffer(), dimension.specialShapeInfo(),
|
dimension.specialShapeInfo(),
|
||||||
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
|
tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets());
|
||||||
// x.printIndexedBuffer("Input");
|
// x.printIndexedBuffer("Input");
|
||||||
// exp.printIndexedBuffer("And");
|
// exp.printIndexedBuffer("And");
|
||||||
|
@ -1095,9 +1134,11 @@ TEST_F(NativeOpsTests, PullRowsTest_1) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
nativeStart[1] = (x.getContext()->getCudaStream());
|
nativeStart[1] = (x.getContext()->getCudaStream());
|
||||||
#endif
|
#endif
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
|
||||||
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(),
|
||||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
&zBuf, z.getShapeInfo(), z.specialShapeInfo(),
|
||||||
4, pidx,
|
4, pidx,
|
||||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||||
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
zTadPack.platformShapeInfo(), zTadPack.platformOffsets());
|
||||||
|
@ -1250,7 +1291,9 @@ TEST_F(NativeOpsTests, RandomTest_1) {
|
||||||
#endif
|
#endif
|
||||||
graph::RandomGenerator rng(1023, 119);
|
graph::RandomGenerator rng(1023, 119);
|
||||||
double p = 0.5;
|
double p = 0.5;
|
||||||
::execRandom(extra, random::BernoulliDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
|
||||||
|
::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NativeOpsTests, RandomTest_2) {
|
TEST_F(NativeOpsTests, RandomTest_2) {
|
||||||
|
@ -1264,7 +1307,10 @@ TEST_F(NativeOpsTests, RandomTest_2) {
|
||||||
x.linspace(0, 0.01);
|
x.linspace(0, 0.01);
|
||||||
graph::RandomGenerator rng(1023, 119);
|
graph::RandomGenerator rng(1023, 119);
|
||||||
double p = 0.5;
|
double p = 0.5;
|
||||||
::execRandom2(extra, random::DropOut, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
|
||||||
|
::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NativeOpsTests, RandomTest_3) {
|
TEST_F(NativeOpsTests, RandomTest_3) {
|
||||||
|
@ -1280,7 +1326,12 @@ TEST_F(NativeOpsTests, RandomTest_3) {
|
||||||
x.linspace(1, -0.01);
|
x.linspace(1, -0.01);
|
||||||
graph::RandomGenerator rng(1023, 119);
|
graph::RandomGenerator rng(1023, 119);
|
||||||
double p = 0.5;
|
double p = 0.5;
|
||||||
::execRandom3(extra, random::ProbablisticMerge, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p);
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
|
||||||
|
::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf,
|
||||||
|
y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NativeOpsTests, RandomTest_4) {
|
TEST_F(NativeOpsTests, RandomTest_4) {
|
||||||
|
@ -1316,6 +1367,10 @@ TEST_F(NativeOpsTests, SortTests_2) {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
extras[1] = LaunchContext::defaultContext()->getCudaStream();
|
extras[1] = LaunchContext::defaultContext()->getCudaStream();
|
||||||
#endif
|
#endif
|
||||||
|
// OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
// OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
// OpaqueDataBuffer expBuf(exp.dataBuffer());
|
||||||
|
// OpaqueDataBuffer dimBuf(exp.dataBuffer());
|
||||||
|
|
||||||
::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
|
@ -1541,6 +1596,13 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) {
|
||||||
::deleteShapeList((Nd4jPointer) shapeList);
|
::deleteShapeList((Nd4jPointer) shapeList);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(NativeOpsTests, interop_databuffer_tests_1) {
|
||||||
|
auto idb = ::allocateDataBuffer(100, 10, false);
|
||||||
|
auto ptr = ::dbPrimaryBuffer(idb);
|
||||||
|
::deleteDataBuffer(idb);
|
||||||
|
}
|
||||||
|
|
||||||
//Uncomment when needed only - massive calculations
|
//Uncomment when needed only - massive calculations
|
||||||
//TEST_F(NativeOpsTests, BenchmarkTests_1) {
|
//TEST_F(NativeOpsTests, BenchmarkTests_1) {
|
||||||
//
|
//
|
||||||
|
|
|
@ -90,4 +90,26 @@ TEST_F(StringTests, Basic_dup_1) {
|
||||||
ASSERT_EQ(f, z1);
|
ASSERT_EQ(f, z1);
|
||||||
|
|
||||||
delete dup;
|
delete dup;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(StringTests, byte_length_test_1) {
|
||||||
|
std::string f("alpha");
|
||||||
|
auto array = NDArrayFactory::string(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(StringTests, byte_length_test_2) {
|
||||||
|
auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"});
|
||||||
|
|
||||||
|
ASSERT_EQ(9, StringUtils::byteLength(array));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(StringTests, test_split_1) {
|
||||||
|
auto split = StringUtils::split("alpha beta gamma", " ");
|
||||||
|
|
||||||
|
ASSERT_EQ(3, split.size());
|
||||||
|
ASSERT_EQ(std::string("alpha"), split[0]);
|
||||||
|
ASSERT_EQ(std::string("beta"), split[1]);
|
||||||
|
ASSERT_EQ(std::string("gamma"), split[2]);
|
||||||
}
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
package org.nd4j.autodiff.listeners.debugging;
|
package org.nd4j.autodiff.listeners.debugging;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.listeners.At;
|
import org.nd4j.autodiff.listeners.At;
|
||||||
import org.nd4j.autodiff.listeners.BaseListener;
|
import org.nd4j.autodiff.listeners.BaseListener;
|
||||||
|
@ -113,16 +114,16 @@ public class ExecDebuggingListener extends BaseListener {
|
||||||
if(co.tArgs() != null && co.tArgs().length > 0) {
|
if(co.tArgs() != null && co.tArgs().length > 0) {
|
||||||
sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
|
sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
|
||||||
}
|
}
|
||||||
INDArray[] inputs = co.inputArguments();
|
val inputs = co.inputArguments();
|
||||||
INDArray[] outputs = co.outputArguments();
|
val outputs = co.outputArguments();
|
||||||
if(inputs != null ) {
|
if(inputs != null ) {
|
||||||
for (int i = 0; i < inputs.length; i++) {
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString());
|
sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(outputs != null ) {
|
if(outputs != null ) {
|
||||||
for (int i = 0; i < outputs.length; i++) {
|
for (int i = 0; i < outputs.size(); i++) {
|
||||||
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString());
|
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -156,22 +157,22 @@ public class ExecDebuggingListener extends BaseListener {
|
||||||
if(co.tArgs() != null && co.tArgs().length > 0 ){
|
if(co.tArgs() != null && co.tArgs().length > 0 ){
|
||||||
sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
|
sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
|
||||||
}
|
}
|
||||||
INDArray[] inputs = co.inputArguments();
|
val inputs = co.inputArguments();
|
||||||
INDArray[] outputs = co.outputArguments();
|
val outputs = co.outputArguments();
|
||||||
if(inputs != null ) {
|
if(inputs != null ) {
|
||||||
sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n");
|
sb.append("INDArray[] inputs = new INDArray[").append(inputs.size()).append("];\n");
|
||||||
for (int i = 0; i < inputs.length; i++) {
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
sb.append("inputs[").append(i).append("] = ");
|
sb.append("inputs[").append(i).append("] = ");
|
||||||
sb.append(createString(inputs[i]))
|
sb.append(createString(inputs.get(i)))
|
||||||
.append(";\n");
|
.append(";\n");
|
||||||
}
|
}
|
||||||
sb.append("op.addInputArgument(inputs);\n");
|
sb.append("op.addInputArgument(inputs);\n");
|
||||||
}
|
}
|
||||||
if(outputs != null ) {
|
if(outputs != null ) {
|
||||||
sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n");
|
sb.append("INDArray[] outputs = new INDArray[").append(outputs.size()).append("];\n");
|
||||||
for (int i = 0; i < outputs.length; i++) {
|
for (int i = 0; i < outputs.size(); i++) {
|
||||||
sb.append("outputs[").append(i).append("] = ");
|
sb.append("outputs[").append(i).append("] = ");
|
||||||
sb.append(createString(outputs[i]))
|
sb.append(createString(outputs.get(i)))
|
||||||
.append(";\n");
|
.append(";\n");
|
||||||
}
|
}
|
||||||
sb.append("op.addOutputArgument(outputs);\n");
|
sb.append("op.addOutputArgument(outputs);\n");
|
||||||
|
|
|
@ -478,11 +478,11 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
|
||||||
}
|
}
|
||||||
throw new IllegalStateException(s);
|
throw new IllegalStateException(s);
|
||||||
}
|
}
|
||||||
return ((Assert) op).outputArguments();
|
return ((Assert) op).outputArguments().toArray(new INDArray[0]);
|
||||||
} else if (op instanceof CustomOp) {
|
} else if (op instanceof CustomOp) {
|
||||||
CustomOp c = (CustomOp) op;
|
CustomOp c = (CustomOp) op;
|
||||||
Nd4j.exec(c);
|
Nd4j.exec(c);
|
||||||
return c.outputArguments();
|
return c.outputArguments().toArray(new INDArray[0]);
|
||||||
} else if (op instanceof Op) {
|
} else if (op instanceof Op) {
|
||||||
Op o = (Op) op;
|
Op o = (Op) op;
|
||||||
Nd4j.exec(o);
|
Nd4j.exec(o);
|
||||||
|
|
|
@ -457,7 +457,7 @@ public class OpValidation {
|
||||||
for (int i = 0; i < testCase.testFns().size(); i++) {
|
for (int i = 0; i < testCase.testFns().size(); i++) {
|
||||||
String error;
|
String error;
|
||||||
try {
|
try {
|
||||||
error = testCase.testFns().get(i).apply(testCase.op().outputArguments()[i]);
|
error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i));
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
|
throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package org.nd4j.autodiff.validation.listeners;
|
package org.nd4j.autodiff.validation.listeners;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.listeners.At;
|
import org.nd4j.autodiff.listeners.At;
|
||||||
import org.nd4j.autodiff.listeners.BaseListener;
|
import org.nd4j.autodiff.listeners.BaseListener;
|
||||||
import org.nd4j.autodiff.listeners.Operation;
|
import org.nd4j.autodiff.listeners.Operation;
|
||||||
|
@ -50,12 +51,12 @@ public class NonInplaceValidationListener extends BaseListener {
|
||||||
opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
|
opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
|
||||||
}
|
}
|
||||||
} else if(op.getOp() instanceof DynamicCustomOp){
|
} else if(op.getOp() instanceof DynamicCustomOp){
|
||||||
INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments();
|
val arr = ((DynamicCustomOp) op.getOp()).inputArguments();
|
||||||
opInputs = new INDArray[arr.length];
|
opInputs = new INDArray[arr.size()];
|
||||||
opInputsOrig = new INDArray[arr.length];
|
opInputsOrig = new INDArray[arr.size()];
|
||||||
for( int i=0; i<arr.length; i++ ){
|
for( int i=0; i<arr.size(); i++ ){
|
||||||
opInputsOrig[i] = arr[i];
|
opInputsOrig[i] = arr.get(i);
|
||||||
opInputs[i] = arr[i].dup();
|
opInputs[i] = arr.get(i).dup();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
|
throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
|
||||||
|
|
|
@ -589,6 +589,10 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.random.impl.Range.class,
|
org.nd4j.linalg.api.ops.random.impl.Range.class,
|
||||||
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
|
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
|
||||||
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
|
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
|
||||||
|
org.nd4j.linalg.api.ops.util.PrintAffinity.class,
|
||||||
|
org.nd4j.linalg.api.ops.util.PrintVariable.class,
|
||||||
|
org.nd4j.linalg.api.ops.compat.CompatSparseToDense.class,
|
||||||
|
org.nd4j.linalg.api.ops.compat.CompatStringSplit.class,
|
||||||
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
||||||
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
||||||
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
||||||
|
|
|
@ -73,7 +73,7 @@ public class ActivationPReLU extends BaseActivationFunction {
|
||||||
preluBp.addIntegerArguments(axis);
|
preluBp.addIntegerArguments(axis);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getExecutioner().execAndReturn(preluBp.build());
|
Nd4j.exec(preluBp.build());
|
||||||
in.assign(outTemp);
|
in.assign(outTemp);
|
||||||
return new Pair<>(in, dLdalpha);
|
return new Pair<>(in, dLdalpha);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import net.ericaro.neoitertools.Generator;
|
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
|
@ -998,14 +997,14 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Pair<DataBuffer, DataBuffer> tadInfo =
|
Pair<DataBuffer, DataBuffer> tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
|
||||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension);
|
|
||||||
DataBuffer shapeInfo = tadInfo.getFirst();
|
DataBuffer shapeInfo = tadInfo.getFirst();
|
||||||
val shape = Shape.shape(shapeInfo);
|
val jShapeInfo = shapeInfo.asLong();
|
||||||
val stride = Shape.stride(shapeInfo).asLong();
|
val shape = Shape.shape(jShapeInfo);
|
||||||
|
val stride = Shape.stride(jShapeInfo);
|
||||||
long offset = offset() + tadInfo.getSecond().getLong(index);
|
long offset = offset() + tadInfo.getSecond().getLong(index);
|
||||||
val ews = shapeInfo.getLong(shapeInfo.getLong(0) * 2 + 2);
|
val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2);
|
||||||
char tadOrder = (char) shapeInfo.getInt(shapeInfo.getLong(0) * 2 + 3);
|
char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3);
|
||||||
val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder);
|
val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder);
|
||||||
return toTad;
|
return toTad;
|
||||||
}
|
}
|
||||||
|
@ -2217,9 +2216,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
if(isEmpty() || isS())
|
if(isEmpty() || isS())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|
val c2 = (length() < data().length() && data.dataType() != DataType.INT);
|
||||||
|| (length() < data().length() && data.dataType() != DataType.INT)
|
val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer());
|
||||||
|| data().originalDataBuffer() != null;
|
|
||||||
|
return c2 || c3;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -3585,6 +3585,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
case HALF:
|
case HALF:
|
||||||
|
case BFLOAT16:
|
||||||
return getDouble(i);
|
return getDouble(i);
|
||||||
case LONG:
|
case LONG:
|
||||||
case INT:
|
case INT:
|
||||||
|
@ -3592,6 +3593,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
case BYTE:
|
case BYTE:
|
||||||
case BOOL:
|
case BOOL:
|
||||||
|
case UINT64:
|
||||||
|
case UINT32:
|
||||||
|
case UINT16:
|
||||||
return getLong(i);
|
return getLong(i);
|
||||||
case UTF8:
|
case UTF8:
|
||||||
case COMPRESSED:
|
case COMPRESSED:
|
||||||
|
@ -4350,29 +4354,30 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
|
|
||||||
//epsilon equals
|
//epsilon equals
|
||||||
if (isScalar() && n.isScalar()) {
|
if (isScalar() && n.isScalar()) {
|
||||||
if (data.dataType() == DataType.FLOAT) {
|
if (isZ()) {
|
||||||
double val = getDouble(0);
|
val val = getLong(0);
|
||||||
double val2 = n.getDouble(0);
|
val val2 = n.getLong(0);
|
||||||
|
|
||||||
|
return val == val2;
|
||||||
|
} else if (isR()) {
|
||||||
|
val val = getDouble(0);
|
||||||
|
val val2 = n.getDouble(0);
|
||||||
|
|
||||||
if (Double.isNaN(val) != Double.isNaN(val2))
|
if (Double.isNaN(val) != Double.isNaN(val2))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return Math.abs(val - val2) < eps;
|
return Math.abs(val - val2) < eps;
|
||||||
} else {
|
} else if (isB()) {
|
||||||
double val = getDouble(0);
|
val val = getInt(0);
|
||||||
double val2 = n.getDouble(0);
|
val val2 = n.getInt(0);
|
||||||
|
|
||||||
if (Double.isNaN(val) != Double.isNaN(val2))
|
return val == val2;
|
||||||
return false;
|
|
||||||
|
|
||||||
return Math.abs(val - val2) < eps;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (isVector() && n.isVector()) {
|
} else if (isVector() && n.isVector()) {
|
||||||
|
val op = new EqualsWithEps(this, n, eps);
|
||||||
EqualsWithEps op = new EqualsWithEps(this, n, eps);
|
Nd4j.exec(op);
|
||||||
Nd4j.getExecutioner().exec(op);
|
val diff = op.z().getDouble(0);
|
||||||
double diff = op.z().getDouble(0);
|
|
||||||
|
|
||||||
return diff < 0.5;
|
return diff < 0.5;
|
||||||
}
|
}
|
||||||
|
@ -4750,8 +4755,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return this;
|
return this;
|
||||||
|
|
||||||
checkArrangeArray(rearrange);
|
checkArrangeArray(rearrange);
|
||||||
int[] newShape = doPermuteSwap(shapeOf(), rearrange);
|
val newShape = doPermuteSwap(shape(), rearrange);
|
||||||
int[] newStride = doPermuteSwap(strideOf(), rearrange);
|
val newStride = doPermuteSwap(stride(), rearrange);
|
||||||
|
|
||||||
char newOrder = Shape.getOrder(newShape, newStride, 1);
|
char newOrder = Shape.getOrder(newShape, newStride, 1);
|
||||||
|
|
||||||
|
@ -4777,23 +4782,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return this;
|
return this;
|
||||||
|
|
||||||
checkArrangeArray(rearrange);
|
checkArrangeArray(rearrange);
|
||||||
val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange);
|
val newShape = doPermuteSwap(shape(), rearrange);
|
||||||
val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange);
|
val newStride = doPermuteSwap(stride(), rearrange);
|
||||||
char newOrder = Shape.getOrder(newShape, newStride, 1);
|
char newOrder = Shape.getOrder(newShape, newStride, 1);
|
||||||
|
|
||||||
//Set the shape information of this array: shape, stride, order.
|
|
||||||
//Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order]
|
|
||||||
/*for( int i=0; i<rank; i++ ){
|
|
||||||
shapeInfo.put(1+i,newShape[i]);
|
|
||||||
shapeInfo.put(1+i+rank,newStride[i]);
|
|
||||||
}
|
|
||||||
shapeInfo.put(3+2*rank,newOrder);
|
|
||||||
*/
|
|
||||||
val ews = shapeInfo.get(2 * rank + 2);
|
val ews = shapeInfo.get(2 * rank + 2);
|
||||||
/*
|
|
||||||
if (ews < 1 && !attemptedToFindElementWiseStride)
|
|
||||||
throw new RuntimeException("EWS is -1");
|
|
||||||
*/
|
|
||||||
|
|
||||||
val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty());
|
val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty());
|
||||||
setShapeInformation(si);
|
setShapeInformation(si);
|
||||||
|
@ -4813,6 +4806,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) {
|
protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) {
|
||||||
val ret = new long[rearrange.length];
|
val ret = new long[rearrange.length];
|
||||||
for (int i = 0; i < rearrange.length; i++) {
|
for (int i = 0; i < rearrange.length; i++) {
|
||||||
|
@ -4821,6 +4815,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) {
|
protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) {
|
||||||
int[] ret = new int[rearrange.length];
|
int[] ret = new int[rearrange.length];
|
||||||
for (int i = 0; i < rearrange.length; i++) {
|
for (int i = 0; i < rearrange.length; i++) {
|
||||||
|
@ -4829,11 +4824,20 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) {
|
protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) {
|
||||||
int[] ret = new int[rearrange.length];
|
int[] ret = new int[rearrange.length];
|
||||||
for (int i = 0; i < rearrange.length; i++) {
|
for (int i = 0; i < rearrange.length; i++) {
|
||||||
ret[i] = shape.getInt(rearrange[i]);
|
ret[i] = shape.getInt(rearrange[i]);
|
||||||
}
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected long[] doPermuteSwap(long[] shape, int[] rearrange) {
|
||||||
|
val ret = new long[rearrange.length];
|
||||||
|
for (int i = 0; i < rearrange.length; i++) {
|
||||||
|
ret[i] = shape[rearrange[i]];
|
||||||
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -5413,29 +5417,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) {
|
protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer);
|
||||||
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
|
|
||||||
try {
|
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
|
||||||
DataOutputStream dos = new DataOutputStream(bos);
|
|
||||||
|
|
||||||
val numWords = this.length();
|
|
||||||
val ub = (Utf8Buffer) buffer;
|
|
||||||
// writing length first
|
|
||||||
val t = length();
|
|
||||||
val ptr = (BytePointer) ub.pointer();
|
|
||||||
|
|
||||||
// now write all strings as bytes
|
|
||||||
for (int i = 0; i < ub.length(); i++) {
|
|
||||||
dos.writeByte(ptr.get(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
val bytes = bos.toByteArray();
|
|
||||||
return FlatArray.createBufferVector(builder, bytes);
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int toFlatArray(FlatBufferBuilder builder) {
|
public int toFlatArray(FlatBufferBuilder builder) {
|
||||||
|
@ -5543,13 +5525,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return !any();
|
return !any();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getString(long index) {
|
|
||||||
if (!isS())
|
|
||||||
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
|
|
||||||
|
|
||||||
return ((Utf8Buffer) data).getString(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validate that the operation is being applied on a numerical array (not boolean or utf8).
|
* Validate that the operation is being applied on a numerical array (not boolean or utf8).
|
||||||
|
|
|
@ -47,12 +47,9 @@ public interface CustomOp {
|
||||||
*/
|
*/
|
||||||
boolean isInplaceCall();
|
boolean isInplaceCall();
|
||||||
|
|
||||||
|
List<INDArray> outputArguments();
|
||||||
|
|
||||||
|
List<INDArray> inputArguments();
|
||||||
|
|
||||||
INDArray[] outputArguments();
|
|
||||||
|
|
||||||
INDArray[] inputArguments();
|
|
||||||
|
|
||||||
long[] iArgs();
|
long[] iArgs();
|
||||||
|
|
||||||
|
|
|
@ -261,19 +261,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] outputArguments() {
|
public List<INDArray> outputArguments() {
|
||||||
if (!outputArguments.isEmpty()) {
|
return outputArguments;
|
||||||
return outputArguments.toArray(new INDArray[0]);
|
|
||||||
}
|
|
||||||
return new INDArray[0];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] inputArguments() {
|
public List<INDArray> inputArguments() {
|
||||||
if (!inputArguments.isEmpty())
|
return inputArguments;
|
||||||
return inputArguments.toArray(new INDArray[0]);
|
|
||||||
return new INDArray[0];
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -367,10 +361,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
for (int i = 0; i < args.length; i++) {
|
for (int i = 0; i < args.length; i++) {
|
||||||
|
|
||||||
// it's possible to get into situation where number of args > number of arrays AT THIS MOMENT
|
// it's possible to get into situation where number of args > number of arrays AT THIS MOMENT
|
||||||
if (i >= arrsSoFar.length)
|
if (i >= arrsSoFar.size())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape()))
|
if (!Arrays.equals(args[i].getShape(), arrsSoFar.get(i).shape()))
|
||||||
throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape()));
|
throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.compat;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a wrapper for SparseToDense op that impelements corresponding TF operation
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class CompatSparseToDense extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public CompatSparseToDense() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values) {
|
||||||
|
Preconditions.checkArgument(shape.isZ() && indices.isZ(), "Shape & indices arrays must have one integer data types");
|
||||||
|
inputArguments.add(indices);
|
||||||
|
inputArguments.add(shape);
|
||||||
|
inputArguments.add(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values, INDArray defaultVaule) {
|
||||||
|
this(indices, shape, values);
|
||||||
|
Preconditions.checkArgument(defaultVaule.dataType() == values.dataType(), "Values array must have the same data type as defaultValue array");
|
||||||
|
inputArguments.add(defaultVaule);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "compat_sparse_to_dense";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.compat;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a wrapper for StringSplit op that impelements corresponding TF operation
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class CompatStringSplit extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public CompatStringSplit() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompatStringSplit(INDArray strings, INDArray delimiter) {
|
||||||
|
Preconditions.checkArgument(strings.isS() && delimiter.isS(), "Input arrays must have one of UTF types");
|
||||||
|
inputArguments.add(strings);
|
||||||
|
inputArguments.add(delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompatStringSplit(INDArray strings, INDArray delimiter, INDArray indices, INDArray values) {
|
||||||
|
this(strings, delimiter);
|
||||||
|
|
||||||
|
outputArguments.add(indices);
|
||||||
|
outputArguments.add(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "compat_string_split";
|
||||||
|
}
|
||||||
|
}
|
|
@ -107,12 +107,12 @@ public class ScatterUpdate implements CustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] outputArguments() {
|
public List<INDArray> outputArguments() {
|
||||||
return op.outputArguments();
|
return op.outputArguments();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] inputArguments() {
|
public List<INDArray> inputArguments() {
|
||||||
return op.inputArguments();
|
return op.inputArguments();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -172,7 +171,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] exec(CustomOp op) {
|
public INDArray[] exec(CustomOp op) {
|
||||||
return execAndReturn(op).outputArguments();
|
return execAndReturn(op).outputArguments().toArray(new INDArray[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -822,7 +821,7 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getString(Utf8Buffer buffer, long index) {
|
public String getString(DataBuffer buffer, long index) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ import lombok.NonNull;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
|
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
|
||||||
import org.nd4j.linalg.api.ops.*;
|
import org.nd4j.linalg.api.ops.*;
|
||||||
|
@ -32,8 +31,6 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.api.shape.TadPack;
|
import org.nd4j.linalg.api.shape.TadPack;
|
||||||
import org.nd4j.linalg.cache.TADManager;
|
import org.nd4j.linalg.cache.TADManager;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.linalg.profiler.OpProfiler;
|
|
||||||
import org.nd4j.linalg.profiler.ProfilerConfig;
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -411,7 +408,7 @@ public interface OpExecutioner {
|
||||||
* @param index
|
* @param index
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
String getString(Utf8Buffer buffer, long index);
|
String getString(DataBuffer buffer, long index);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Temporary hook
|
* Temporary hook
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.util;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a wrapper for PrintAffinity op that just prints out affinity & locality status of INDArray
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class PrintAffinity extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public PrintAffinity() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintAffinity(INDArray array) {
|
||||||
|
inputArguments.add(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "print_affinity";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,66 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.util;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is a wrapper for PrintVariable op that just prints out Variable to the stdout
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class PrintVariable extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public PrintVariable() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array, boolean printSpecial) {
|
||||||
|
inputArguments.add(array);
|
||||||
|
bArguments.add(printSpecial);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array) {
|
||||||
|
this(array, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array, String message, boolean printSpecial) {
|
||||||
|
this(array, Nd4j.create(message), printSpecial);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array, String message) {
|
||||||
|
this(array, Nd4j.create(message), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array, INDArray message, boolean printSpecial) {
|
||||||
|
this(array, printSpecial);
|
||||||
|
Preconditions.checkArgument(message.isS(), "Message argument should have String data type, but got [" + message.dataType() +"] instead");
|
||||||
|
inputArguments.add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PrintVariable(INDArray array, INDArray message) {
|
||||||
|
this(array, message, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "print_variable";
|
||||||
|
}
|
||||||
|
}
|
|
@ -89,6 +89,11 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pointer addressPointer() {
|
||||||
|
return pointer;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer
|
* Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer
|
||||||
* @param s
|
* @param s
|
||||||
|
@ -194,6 +199,15 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer create(int[] data) {
|
public DataBuffer create(int[] data) {
|
||||||
throw new UnsupportedOperationException("This operation isn't supported for CompressedDataBuffer");
|
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void pointerIndexerByCurrentType(DataType currentType) {
|
||||||
|
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataBuffer reallocate(long length) {
|
||||||
|
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,7 +98,7 @@ public class Convolution {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Nd4j.getExecutioner().execAndReturn(col2Im);
|
Nd4j.getExecutioner().execAndReturn(col2Im);
|
||||||
return col2Im.outputArguments()[0];
|
return col2Im.outputArguments().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW,
|
public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW,
|
||||||
|
@ -187,7 +187,7 @@ public class Convolution {
|
||||||
.build()).build();
|
.build()).build();
|
||||||
|
|
||||||
Nd4j.getExecutioner().execAndReturn(im2col);
|
Nd4j.getExecutioner().execAndReturn(im2col);
|
||||||
return im2col.outputArguments()[0];
|
return im2col.outputArguments().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode,
|
public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode,
|
||||||
|
@ -208,7 +208,7 @@ public class Convolution {
|
||||||
.build()).build();
|
.build()).build();
|
||||||
|
|
||||||
Nd4j.getExecutioner().execAndReturn(im2col);
|
Nd4j.getExecutioner().execAndReturn(im2col);
|
||||||
return im2col.outputArguments()[0];
|
return im2col.outputArguments().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -298,7 +298,7 @@ public class Convolution {
|
||||||
.build()).build();
|
.build()).build();
|
||||||
|
|
||||||
Nd4j.getExecutioner().execAndReturn(im2col);
|
Nd4j.getExecutioner().execAndReturn(im2col);
|
||||||
return im2col.outputArguments()[0];
|
return im2col.outputArguments().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -40,7 +40,6 @@ import org.nd4j.graph.FlatArray;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.*;
|
import org.nd4j.linalg.api.buffer.*;
|
||||||
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
||||||
import org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory;
|
|
||||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
|
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
|
||||||
|
@ -1044,16 +1043,7 @@ public class Nd4j {
|
||||||
* @return the created buffer
|
* @return the created buffer
|
||||||
*/
|
*/
|
||||||
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) {
|
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) {
|
||||||
switch (type) {
|
return DATA_BUFFER_FACTORY_INSTANCE.create(buffer, type, length, offset);
|
||||||
case INT:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, buffer, length);
|
|
||||||
case DOUBLE:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, buffer, length);
|
|
||||||
case FLOAT:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, buffer, length);
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException("Illegal opType " + type);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1336,38 +1326,9 @@ public class Nd4j {
|
||||||
* @return the created buffer
|
* @return the created buffer
|
||||||
*/
|
*/
|
||||||
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
|
public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) {
|
||||||
switch (type) {
|
return createBuffer(buffer, type, length, 0);
|
||||||
case INT:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createInt(buffer, length);
|
|
||||||
case LONG:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createLong(buffer, length);
|
|
||||||
case DOUBLE:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createDouble(buffer, length);
|
|
||||||
case FLOAT:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createFloat(buffer, length);
|
|
||||||
case HALF:
|
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE.createHalf(buffer, length);
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException("Illegal opType " + type);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a buffer based on the data opType
|
|
||||||
*
|
|
||||||
* @param data the data to create the buffer with
|
|
||||||
* @return the created buffer
|
|
||||||
*/
|
|
||||||
public static DataBuffer createBuffer(byte[] data, int length) {
|
|
||||||
DataBuffer ret;
|
|
||||||
if (dataType() == DataType.DOUBLE)
|
|
||||||
ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length);
|
|
||||||
else if (dataType() == DataType.HALF)
|
|
||||||
ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length);
|
|
||||||
else
|
|
||||||
ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a buffer equal of length prod(shape)
|
* Create a buffer equal of length prod(shape)
|
||||||
|
@ -2206,6 +2167,7 @@ public class Nd4j {
|
||||||
private static String writeStringForArray(INDArray write) {
|
private static String writeStringForArray(INDArray write) {
|
||||||
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
|
if(write.isView() || !Shape.hasDefaultStridesForShape(write))
|
||||||
write = write.dup();
|
write = write.dup();
|
||||||
|
|
||||||
String format = "0.000000000000000000E0";
|
String format = "0.000000000000000000E0";
|
||||||
|
|
||||||
return "{\n" +
|
return "{\n" +
|
||||||
|
@ -3927,16 +3889,6 @@ public class Nd4j {
|
||||||
return create(shape, stride);
|
return create(shape, stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an ndarray with the specified shape
|
|
||||||
*
|
|
||||||
* @param rows the rows of the ndarray
|
|
||||||
* @param columns the columns of the ndarray
|
|
||||||
* @return the instance
|
|
||||||
*/
|
|
||||||
public static INDArray create(int rows, int columns) {
|
|
||||||
return create(rows, columns, order());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an ndarray with the specified shape
|
* Creates an ndarray with the specified shape
|
||||||
|
@ -4386,13 +4338,6 @@ public class Nd4j {
|
||||||
return createUninitialized(shape, Nd4j.order());
|
return createUninitialized(shape, Nd4j.order());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* See {@link #createUninitialized(long)}
|
|
||||||
*/
|
|
||||||
public static INDArray createUninitialized(int length) {
|
|
||||||
return createUninitialized((long)length);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method creates an *uninitialized* ndarray of specified length and default ordering.
|
* This method creates an *uninitialized* ndarray of specified length and default ordering.
|
||||||
*
|
*
|
||||||
|
@ -4428,37 +4373,6 @@ public class Nd4j {
|
||||||
|
|
||||||
////////////////////// OTHER ///////////////////////////////
|
////////////////////// OTHER ///////////////////////////////
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a 2D array with specified number of rows, columns initialized with zero.
|
|
||||||
*
|
|
||||||
* @param rows number of rows.
|
|
||||||
* @param columns number of columns.
|
|
||||||
* @return the created array.
|
|
||||||
*/
|
|
||||||
public static INDArray zeros(long rows, long columns) {
|
|
||||||
return INSTANCE.zeros(rows, columns);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a 1D array with the specified number of columns initialized with zero.
|
|
||||||
*
|
|
||||||
* @param columns number of columns.
|
|
||||||
* @return the created array
|
|
||||||
*/
|
|
||||||
public static INDArray zeros(int columns) {
|
|
||||||
return INSTANCE.zeros(columns);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a 1D array with the specified data tyoe and number of columns initialized with zero.
|
|
||||||
*
|
|
||||||
* @param dataType data type.
|
|
||||||
* @param columns number of columns.
|
|
||||||
* @return the created array.
|
|
||||||
*/
|
|
||||||
public static INDArray zeros(DataType dataType, int columns) {
|
|
||||||
return INSTANCE.create(dataType, new long[]{columns}, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an array with the specified data tyoe and shape initialized with zero.
|
* Creates an array with the specified data tyoe and shape initialized with zero.
|
||||||
|
@ -4468,7 +4382,10 @@ public class Nd4j {
|
||||||
* @return the created array.
|
* @return the created array.
|
||||||
*/
|
*/
|
||||||
public static INDArray zeros(DataType dataType, @NonNull long... shape) {
|
public static INDArray zeros(DataType dataType, @NonNull long... shape) {
|
||||||
return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace());
|
if(shape.length == 0)
|
||||||
|
return Nd4j.scalar(dataType, 0);
|
||||||
|
|
||||||
|
return INSTANCE.create(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4588,31 +4505,6 @@ public class Nd4j {
|
||||||
return INSTANCE.valueArrayOf(rows, columns, value);
|
return INSTANCE.valueArrayOf(rows, columns, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a row vector with the specified number of columns
|
|
||||||
*
|
|
||||||
* @param rows the number of rows in the matrix
|
|
||||||
* @param columns the columns of the ndarray
|
|
||||||
* @return the created ndarray
|
|
||||||
*/
|
|
||||||
public static INDArray ones(int rows, int columns) {
|
|
||||||
return INSTANCE.ones(rows, columns);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a 2D array with the given rows, columns and data type initialised with ones.
|
|
||||||
*
|
|
||||||
* @param dataType data type
|
|
||||||
* @param rows rows of the new array.
|
|
||||||
* @param columns columns of the new arrau.
|
|
||||||
* @return the created array
|
|
||||||
*/
|
|
||||||
public static INDArray ones(DataType dataType, int rows, int columns) {
|
|
||||||
INDArray ret = INSTANCE.createUninitialized(dataType, new long[]{rows, columns}, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace());
|
|
||||||
ret.assign(1);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Empty like
|
* Empty like
|
||||||
*
|
*
|
||||||
|
@ -4817,8 +4709,7 @@ public class Nd4j {
|
||||||
|
|
||||||
for (int idx : indexes) {
|
for (int idx : indexes) {
|
||||||
if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) {
|
if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException("Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
|
||||||
"Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5186,7 +5077,7 @@ public class Nd4j {
|
||||||
pp.toString(NDARRAY_FACTORY_CLASS));
|
pp.toString(NDARRAY_FACTORY_CLASS));
|
||||||
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
||||||
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
||||||
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory");
|
||||||
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
||||||
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
||||||
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
||||||
|
@ -5871,7 +5762,7 @@ public class Nd4j {
|
||||||
arr[e] = sb.get(e + pos);
|
arr[e] = sb.get(e + pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
val buffer = new Utf8Buffer(arr, prod);
|
val buffer = DATA_BUFFER_FACTORY_INSTANCE.createUtf8Buffer(arr, prod);
|
||||||
return Nd4j.create(buffer, shapeOf);
|
return Nd4j.create(buffer, shapeOf);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
|
|
|
@ -30,6 +30,7 @@ import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class provides unified management for Deallocatable resources
|
* This class provides unified management for Deallocatable resources
|
||||||
|
@ -43,6 +44,8 @@ public class DeallocatorService {
|
||||||
private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
|
private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
|
||||||
private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();
|
private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();
|
||||||
|
|
||||||
|
private AtomicLong counter = new AtomicLong(0);
|
||||||
|
|
||||||
public DeallocatorService() {
|
public DeallocatorService() {
|
||||||
// we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
|
// we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
|
||||||
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||||
|
@ -69,6 +72,10 @@ public class DeallocatorService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public long nextValue() {
|
||||||
|
return counter.incrementAndGet();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method adds Deallocatable object instance to tracking system
|
* This method adds Deallocatable object instance to tracking system
|
||||||
*
|
*
|
||||||
|
|
|
@ -17,10 +17,10 @@
|
||||||
package org.nd4j.serde.jackson.shaded;
|
package org.nd4j.serde.jackson.shaded;
|
||||||
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.serde.base64.Nd4jBase64;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||||
|
@ -77,10 +77,9 @@ public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
|
||||||
jg.writeNumber(v);
|
jg.writeNumber(v);
|
||||||
break;
|
break;
|
||||||
case UTF8:
|
case UTF8:
|
||||||
Utf8Buffer utf8B = ((Utf8Buffer)arr.data());
|
val n = arr.length();
|
||||||
long n = utf8B.getNumWords();
|
|
||||||
for( int j=0; j<n; j++ ) {
|
for( int j=0; j<n; j++ ) {
|
||||||
String s = utf8B.getString(j);
|
String s = arr.getString(j);
|
||||||
jg.writeString(s);
|
jg.writeString(s);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -16,11 +16,8 @@
|
||||||
|
|
||||||
package org.nd4j.nativeblas;
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import org.bytedeco.javacpp.*;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.annotation.Cast;
|
import org.bytedeco.javacpp.annotation.Cast;
|
||||||
import org.bytedeco.javacpp.indexer.LongIndexer;
|
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -53,14 +50,12 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execIndexReduceScalar(PointerPointer extraPointers,
|
void execIndexReduceScalar(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dX,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer z,
|
OpaqueDataBuffer z,
|
||||||
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
Pointer dZ,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dZShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dZShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -75,17 +70,16 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execIndexReduce(PointerPointer extraPointers,
|
void execIndexReduce(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dX,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dXShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
||||||
Pointer dResult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer,
|
@Cast("Nd4jLong *") LongPointer dResultShapeInfoBuffer,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
@ -100,38 +94,34 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execBroadcast(PointerPointer extraPointers,
|
void execBroadcast(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer y,
|
OpaqueDataBuffer y,
|
||||||
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer dy,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
void execBroadcastBool(PointerPointer extraPointers,
|
void execBroadcastBool(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer y,
|
OpaqueDataBuffer y,
|
||||||
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer dy,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -146,33 +136,27 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execPairwiseTransform(PointerPointer extraPointers,
|
void execPairwiseTransform(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer y,
|
OpaqueDataBuffer y,
|
||||||
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer dy,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
void execPairwiseTransformBool(PointerPointer extraPointers,
|
void execPairwiseTransformBool(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer y,
|
OpaqueDataBuffer y,
|
||||||
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer dy,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
|
@ -186,53 +170,45 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execReduceFloat(PointerPointer extraPointers,
|
void execReduceFloat(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
void execReduceSame(PointerPointer extraPointers,
|
void execReduceSame(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
void execReduceBool(PointerPointer extraPointers,
|
void execReduceBool(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
void execReduceLong(PointerPointer extraPointers,
|
void execReduceLong(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -245,60 +221,56 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execReduceFloat2(PointerPointer extraPointers,
|
void execReduceFloat2(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
|
|
||||||
void execReduceSame2(PointerPointer extraPointers,
|
void execReduceSame2(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
void execReduceBool2(PointerPointer extraPointers,
|
void execReduceBool2(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
void execReduceLong2(PointerPointer extraPointers,
|
void execReduceLong2(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x,
|
OpaqueDataBuffer x,
|
||||||
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer dx,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result,
|
OpaqueDataBuffer result,
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
@ -312,13 +284,16 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execReduce3(PointerPointer extraPointers,
|
void execReduce3(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParamsVals,
|
Pointer extraParamsVals,
|
||||||
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
|
OpaqueDataBuffer y,
|
||||||
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
OpaqueDataBuffer result,
|
||||||
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
@ -329,13 +304,16 @@ public interface NativeOps {
|
||||||
* @param yShapeInfo
|
* @param yShapeInfo
|
||||||
*/
|
*/
|
||||||
void execReduce3Scalar(PointerPointer extraPointers, int opNum,
|
void execReduce3Scalar(PointerPointer extraPointers, int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer extraParamsVals,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
|
Pointer extraParamsVals,
|
||||||
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
OpaqueDataBuffer y,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo);
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
|
OpaqueDataBuffer z,
|
||||||
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param opNum
|
* @param opNum
|
||||||
|
@ -351,29 +329,37 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execReduce3Tad(PointerPointer extraPointers,
|
void execReduce3Tad(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParamsVals,
|
Pointer extraParamsVals,
|
||||||
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
|
OpaqueDataBuffer y,
|
||||||
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
OpaqueDataBuffer result,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
||||||
@Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
|
OpaqueDataBuffer hDimension,
|
||||||
@Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets);
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
|
||||||
|
@Cast("Nd4jLong *") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong *") LongPointer yTadOffsets);
|
||||||
|
|
||||||
void execReduce3All(PointerPointer extraPointers,
|
void execReduce3All(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParamsVals,
|
Pointer extraParamsVals,
|
||||||
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeInfo,
|
OpaqueDataBuffer y,
|
||||||
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
@Cast("Nd4jLong *") LongPointer yShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
@Cast("Nd4jLong *") LongPointer dyShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
OpaqueDataBuffer result,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
||||||
|
OpaqueDataBuffer hDimension,
|
||||||
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer xTadShape,
|
@Cast("Nd4jLong *") LongPointer xTadShape,
|
||||||
@Cast("Nd4jLong *") LongPointer xOffsets,
|
@Cast("Nd4jLong *") LongPointer xOffsets,
|
||||||
@Cast("Nd4jLong *") LongPointer yTadShape,
|
@Cast("Nd4jLong *") LongPointer yTadShape,
|
||||||
|
@ -391,22 +377,28 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execScalar(PointerPointer extraPointers,
|
void execScalar(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
OpaqueDataBuffer scalar,
|
||||||
|
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
void execScalarBool(PointerPointer extraPointers,
|
void execScalarBool(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer scalar, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dscalar, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
OpaqueDataBuffer scalar,
|
||||||
|
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -418,11 +410,13 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execSummaryStatsScalar(PointerPointer extraPointers,
|
void execSummaryStatsScalar(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
OpaqueDataBuffer z,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
||||||
boolean biasCorrected);
|
boolean biasCorrected);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -436,11 +430,13 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execSummaryStats(PointerPointer extraPointers,
|
void execSummaryStats(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
boolean biasCorrected);
|
boolean biasCorrected);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -454,17 +450,20 @@ public interface NativeOps {
|
||||||
* @param dimensionLength
|
* @param dimensionLength
|
||||||
*/
|
*/
|
||||||
void execSummaryStatsTad(PointerPointer extraPointers,
|
void execSummaryStatsTad(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer extraParams,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
Pointer extraParams,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
OpaqueDataBuffer result,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfoBuffer,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfoBuffer,
|
||||||
boolean biasCorrected,
|
OpaqueDataBuffer hDimension,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer tadOffsets);
|
@Cast("Nd4jLong *") LongPointer dDimensionShape,
|
||||||
|
boolean biasCorrected,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsets);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -478,43 +477,53 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execTransformFloat(PointerPointer extraPointers,
|
void execTransformFloat(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
void execTransformSame(PointerPointer extraPointers,
|
void execTransformSame(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
Pointer extraParams);
|
Pointer extraParams);
|
||||||
|
|
||||||
void execTransformStrict(PointerPointer extraPointers,
|
void execTransformStrict(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer extraParams);
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
Pointer extraParams);
|
||||||
|
|
||||||
void execTransformBool(PointerPointer extraPointers,
|
void execTransformBool(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer extraParams);
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
Pointer extraParams);
|
||||||
|
|
||||||
void execTransformAny(PointerPointer extraPointers,
|
void execTransformAny(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer result, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
OpaqueDataBuffer result,
|
||||||
Pointer extraParams);
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
Pointer extraParams);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ScalarOp along dimension
|
* ScalarOp along dimension
|
||||||
|
@ -532,31 +541,43 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
void execScalarTad(PointerPointer extraPointers,
|
void execScalarTad(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
OpaqueDataBuffer z,
|
||||||
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
||||||
|
OpaqueDataBuffer scalars,
|
||||||
|
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
|
@Cast("Nd4jLong *") LongPointer dDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);
|
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsets,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
|
||||||
|
|
||||||
void execScalarBoolTad(PointerPointer extraPointers,
|
void execScalarBoolTad(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
OpaqueDataBuffer z,
|
||||||
Pointer scalars, @Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
Pointer dscalars, @Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
||||||
|
OpaqueDataBuffer scalars,
|
||||||
|
@Cast("Nd4jLong *") LongPointer scalarShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dscalarShapeInfo,
|
||||||
Pointer extraParams,
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
OpaqueDataBuffer hDimension,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape,
|
@Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
|
@Cast("Nd4jLong *") LongPointer dDimensionShape,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);
|
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsets,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsetsZ);
|
||||||
|
|
||||||
|
|
||||||
void specialConcat(PointerPointer extraPointers,
|
void specialConcat(PointerPointer extraPointers,
|
||||||
|
@ -675,10 +696,12 @@ public interface NativeOps {
|
||||||
///////////////
|
///////////////
|
||||||
|
|
||||||
void pullRows(PointerPointer extraPointers,
|
void pullRows(PointerPointer extraPointers,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
OpaqueDataBuffer z,
|
||||||
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeInfo,
|
||||||
long n,
|
long n,
|
||||||
@Cast("Nd4jLong *") LongPointer indexes,
|
@Cast("Nd4jLong *") LongPointer indexes,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
||||||
|
@ -777,28 +800,34 @@ public interface NativeOps {
|
||||||
void execRandom(PointerPointer extraPointers,
|
void execRandom(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer state,
|
Pointer state,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
OpaqueDataBuffer z,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
||||||
Pointer extraArguments);
|
Pointer extraArguments);
|
||||||
|
|
||||||
void execRandom3(PointerPointer extraPointers,
|
void execRandom3(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer state,
|
Pointer state,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer xShapeBuffer,
|
||||||
Pointer y, @Cast("Nd4jLong *") LongPointer yShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer dxShapeBuffer,
|
||||||
Pointer dy, @Cast("Nd4jLong *") LongPointer dyShapeBuffer,
|
OpaqueDataBuffer y,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer yShapeBuffer,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer dyShapeBuffer,
|
||||||
|
OpaqueDataBuffer z,
|
||||||
|
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
||||||
Pointer extraArguments);
|
Pointer extraArguments);
|
||||||
|
|
||||||
void execRandom2(PointerPointer extraPointers,
|
void execRandom2(PointerPointer extraPointers,
|
||||||
int opNum,
|
int opNum,
|
||||||
Pointer state,
|
Pointer state,
|
||||||
Pointer x, @Cast("Nd4jLong *") LongPointer xShapeBuffer,
|
OpaqueDataBuffer x,
|
||||||
Pointer dx, @Cast("Nd4jLong *") LongPointer dxShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer xShapeBuffer,
|
||||||
Pointer z, @Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
@Cast("Nd4jLong *") LongPointer dxShapeBuffer,
|
||||||
Pointer dz, @Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
OpaqueDataBuffer z,
|
||||||
|
@Cast("Nd4jLong *") LongPointer zShapeBuffer,
|
||||||
|
@Cast("Nd4jLong *") LongPointer dzShapeBuffer,
|
||||||
Pointer extraArguments);
|
Pointer extraArguments);
|
||||||
|
|
||||||
////////////////////
|
////////////////////
|
||||||
|
@ -967,11 +996,13 @@ public interface NativeOps {
|
||||||
|
|
||||||
|
|
||||||
void tear(PointerPointer extras,
|
void tear(PointerPointer extras,
|
||||||
Pointer tensor, @Cast("Nd4jLong *") LongPointer xShapeInfo,
|
OpaqueDataBuffer tensor,
|
||||||
Pointer dtensor, @Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
@Cast("Nd4jLong *") LongPointer xShapeInfo,
|
||||||
PointerPointer targets, @Cast("Nd4jLong *") LongPointer zShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dxShapeInfo,
|
||||||
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
PointerPointer targets,
|
||||||
@Cast("Nd4jLong *") LongPointer tadOffsets);
|
@Cast("Nd4jLong *") LongPointer zShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadShapeInfo,
|
||||||
|
@Cast("Nd4jLong *") LongPointer tadOffsets);
|
||||||
|
|
||||||
|
|
||||||
long encodeBitmap(PointerPointer extraPointers, Pointer dx, LongPointer xShapeInfo, long N, IntPointer dz, float threshold);
|
long encodeBitmap(PointerPointer extraPointers, Pointer dx, LongPointer xShapeInfo, long N, IntPointer dz, float threshold);
|
||||||
|
@ -1121,6 +1152,8 @@ public interface NativeOps {
|
||||||
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||||
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
|
||||||
|
void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo);
|
||||||
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||||
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||||
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||||
|
@ -1162,4 +1195,27 @@ public interface NativeOps {
|
||||||
|
|
||||||
boolean isMinimalRequirementsMet();
|
boolean isMinimalRequirementsMet();
|
||||||
boolean isOptimalRequirementsMet();
|
boolean isOptimalRequirementsMet();
|
||||||
|
|
||||||
|
|
||||||
|
OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth);
|
||||||
|
OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset);
|
||||||
|
Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
|
||||||
|
Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbExpandBuffer(OpaqueDataBuffer dataBuffer, long elements);
|
||||||
|
void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, Pointer primaryBuffer, long numBytes);
|
||||||
|
void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, Pointer specialBuffer, long numBytes);
|
||||||
|
void dbSyncToSpecial(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbSyncToPrimary(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbTickHostRead(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbTickHostWrite(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbTickDeviceRead(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer);
|
||||||
|
void deleteDataBuffer(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbClose(OpaqueDataBuffer dataBuffer);
|
||||||
|
int dbLocality(OpaqueDataBuffer dataBuffer);
|
||||||
|
int dbDeviceId(OpaqueDataBuffer dataBuffer);
|
||||||
|
void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId);
|
||||||
|
void dbExpand(OpaqueDataBuffer dataBuffer, long newLength);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,206 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is a opaque pointer to InteropDataBuffer, used for Java/C++ interop related to INDArray DataBuffer
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueDataBuffer extends Pointer {
|
||||||
|
// TODO: make this configurable
|
||||||
|
private static final int MAX_TRIES = 3;
|
||||||
|
|
||||||
|
public OpaqueDataBuffer(Pointer p) { super(p); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allocates new InteropDataBuffer and returns pointer to it
|
||||||
|
* @param numElements
|
||||||
|
* @param dataType
|
||||||
|
* @param allocateBoth
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static OpaqueDataBuffer allocateDataBuffer(long numElements, @NonNull DataType dataType, boolean allocateBoth) {
|
||||||
|
OpaqueDataBuffer buffer = null;
|
||||||
|
int ec = 0;
|
||||||
|
String em = null;
|
||||||
|
|
||||||
|
for (int t = 0; t < MAX_TRIES; t++) {
|
||||||
|
try {
|
||||||
|
// try to allocate data buffer
|
||||||
|
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(numElements, dataType.toInt(), allocateBoth);
|
||||||
|
|
||||||
|
// check error code
|
||||||
|
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
|
if (ec != 0) {
|
||||||
|
if (em == null)
|
||||||
|
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
|
||||||
|
|
||||||
|
// if allocation failed it might be caused by casual OOM, so we'll try GC
|
||||||
|
System.gc();
|
||||||
|
} else {
|
||||||
|
// just return the buffer
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if MAX_TRIES is over, we'll just throw an exception
|
||||||
|
throw new RuntimeException("Allocation failed: [" + em + "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method expands buffer, and copies content to the new buffer
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: if InteropDataBuffer doesn't own actual buffers - original pointers won't be released
|
||||||
|
* @param numElements
|
||||||
|
*/
|
||||||
|
public void expand(long numElements) {
|
||||||
|
int ec = 0;
|
||||||
|
String em = null;
|
||||||
|
|
||||||
|
for (int t = 0; t < MAX_TRIES; t++) {
|
||||||
|
try {
|
||||||
|
// try to expand the buffer
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbExpand(this, numElements);
|
||||||
|
|
||||||
|
// check error code
|
||||||
|
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
|
if (ec != 0) {
|
||||||
|
if (em == null)
|
||||||
|
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
|
||||||
|
|
||||||
|
// if expansion failed it might be caused by casual OOM, so we'll try GC
|
||||||
|
System.gc();
|
||||||
|
} else {
|
||||||
|
// just return
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if MAX_TRIES is over, we'll just throw an exception
|
||||||
|
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method creates a view out of this InteropDataBuffer
|
||||||
|
*
|
||||||
|
* @param bytesLength
|
||||||
|
* @param bytesOffset
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public OpaqueDataBuffer createView(long bytesLength, long bytesOffset) {
|
||||||
|
OpaqueDataBuffer buffer = null;
|
||||||
|
int ec = 0;
|
||||||
|
String em = null;
|
||||||
|
|
||||||
|
for (int t = 0; t < MAX_TRIES; t++) {
|
||||||
|
try {
|
||||||
|
buffer = NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateView(this, bytesLength, bytesOffset);
|
||||||
|
|
||||||
|
// check error code
|
||||||
|
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
|
||||||
|
if (ec != 0) {
|
||||||
|
if (em == null)
|
||||||
|
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
|
||||||
|
|
||||||
|
// if view creation failed it might be caused by casual OOM, so we'll try GC
|
||||||
|
System.gc();
|
||||||
|
} else {
|
||||||
|
// just return
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if MAX_TRIES is over, we'll just throw an exception
|
||||||
|
throw new RuntimeException("DataBuffer expansion failed: [" + em + "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns pointer to linear buffer, primary one.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Pointer primaryBuffer() {
|
||||||
|
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns pointer to special buffer, device one, if any.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public Pointer specialBuffer() {
|
||||||
|
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns deviceId of this DataBuffer
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public int deviceId() {
|
||||||
|
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbDeviceId(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows to set external pointer as primary buffer.
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released
|
||||||
|
* @param ptr
|
||||||
|
* @param numElements
|
||||||
|
*/
|
||||||
|
public void setPrimaryBuffer(Pointer ptr, long numElements) {
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(this, ptr, numElements);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method allows to set external pointer as primary buffer.
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: if InteropDataBuffer owns current memory buffer, it will be released
|
||||||
|
* @param ptr
|
||||||
|
* @param numElements
|
||||||
|
*/
|
||||||
|
public void setSpecialBuffer(Pointer ptr, long numElements) {
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(this, ptr, numElements);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method synchronizes device memory
|
||||||
|
*/
|
||||||
|
public void syncToSpecial() {
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method synchronizes host memory
|
||||||
|
*/
|
||||||
|
public void syncToPrimary() {
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this);
|
||||||
|
}
|
||||||
|
}
|
|
@ -253,6 +253,7 @@
|
||||||
<version>${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version}</version>
|
<version>${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version}</version>
|
||||||
<classifier>${dependency.platform}</classifier>
|
<classifier>${dependency.platform}</classifier>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<!--
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>libnd4j</artifactId>
|
<artifactId>libnd4j</artifactId>
|
||||||
|
@ -261,6 +262,7 @@
|
||||||
<classifier>${javacpp.platform}-cuda-${cuda.version}</classifier>
|
<classifier>${javacpp.platform}-cuda-${cuda.version}</classifier>
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
-->
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>junit</groupId>
|
<groupId>junit</groupId>
|
||||||
<artifactId>junit</artifactId>
|
<artifactId>junit</artifactId>
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.jita.allocator.impl;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
||||||
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
|
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
|
||||||
|
@ -29,9 +30,11 @@ import org.nd4j.jita.allocator.time.providers.MillisecondsProvider;
|
||||||
import org.nd4j.jita.allocator.time.providers.OperativeProvider;
|
import org.nd4j.jita.allocator.time.providers.OperativeProvider;
|
||||||
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueDataBuffer;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
@ -54,8 +57,8 @@ import java.util.concurrent.locks.ReentrantLock;
|
||||||
public class AllocationPoint {
|
public class AllocationPoint {
|
||||||
private static Logger log = LoggerFactory.getLogger(AllocationPoint.class);
|
private static Logger log = LoggerFactory.getLogger(AllocationPoint.class);
|
||||||
|
|
||||||
// thread safety is guaranteed by cudaLock
|
@Getter
|
||||||
private volatile PointersPair pointerInfo;
|
private OpaqueDataBuffer ptrDataBuffer;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
|
@ -104,33 +107,27 @@ public class AllocationPoint {
|
||||||
*/
|
*/
|
||||||
private volatile int deviceId;
|
private volatile int deviceId;
|
||||||
|
|
||||||
public AllocationPoint() {
|
private long bytes;
|
||||||
//
|
|
||||||
|
public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) {
|
||||||
|
ptrDataBuffer = opaqueDataBuffer;
|
||||||
|
this.bytes = bytes;
|
||||||
|
objectId = Nd4j.getDeallocatorService().nextValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void acquireLock() {
|
public void setPointers(Pointer primary, Pointer special, long numberOfElements) {
|
||||||
//lock.lock();
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, primary, numberOfElements);
|
||||||
}
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, special, numberOfElements);
|
||||||
|
|
||||||
public void releaseLock() {
|
|
||||||
//lock.unlock();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getDeviceId() {
|
public int getDeviceId() {
|
||||||
return deviceId;
|
return ptrDataBuffer.deviceId();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setDeviceId(int deviceId) {
|
public void setDeviceId(int deviceId) {
|
||||||
this.deviceId = deviceId;
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
We assume 1D memory chunk allocations.
|
|
||||||
*/
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
private AllocationShape shape;
|
|
||||||
|
|
||||||
private AtomicBoolean enqueued = new AtomicBoolean(false);
|
private AtomicBoolean enqueued = new AtomicBoolean(false);
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -164,7 +161,7 @@ public class AllocationPoint {
|
||||||
}
|
}
|
||||||
|
|
||||||
public long getNumberOfBytes() {
|
public long getNumberOfBytes() {
|
||||||
return shape.getNumberOfBytes();
|
return bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -220,67 +217,25 @@ public class AllocationPoint {
|
||||||
* This method returns CUDA pointer object for this allocation.
|
* This method returns CUDA pointer object for this allocation.
|
||||||
* It can be either device pointer or pinned memory pointer, or null.
|
* It can be either device pointer or pinned memory pointer, or null.
|
||||||
*
|
*
|
||||||
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Pointer getDevicePointer() {
|
public Pointer getDevicePointer() {
|
||||||
if (pointerInfo == null) {
|
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(ptrDataBuffer);
|
||||||
log.info("pointerInfo is null");
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return pointerInfo.getDevicePointer();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns CUDA pointer object for this allocation.
|
* This method returns CUDA pointer object for this allocation.
|
||||||
* It can be either device pointer or pinned memory pointer, or null.
|
* It can be either device pointer or pinned memory pointer, or null.
|
||||||
*
|
*
|
||||||
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Pointer getHostPointer() {
|
public Pointer getHostPointer() {
|
||||||
if (pointerInfo == null)
|
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(ptrDataBuffer);
|
||||||
return null;
|
|
||||||
|
|
||||||
return pointerInfo.getHostPointer();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method sets CUDA pointer for this allocation.
|
|
||||||
* It can be either device pointer, or pinned memory pointer, or null.
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock
|
|
||||||
* @param pointerInfo CUDA pointers wrapped into DevicePointerInfo
|
|
||||||
*/
|
|
||||||
public void setPointers(@NonNull PointersPair pointerInfo) {
|
|
||||||
this.pointerInfo = pointerInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
public PointersPair getPointers() {
|
|
||||||
return this.pointerInfo;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public synchronized void tickDeviceRead() {
|
public synchronized void tickDeviceRead() {
|
||||||
// this.deviceTicks.incrementAndGet();
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceRead(ptrDataBuffer);
|
||||||
// this.timerShort.triggerEvent();
|
|
||||||
// this.timerLong.triggerEvent();
|
|
||||||
//this.deviceAccessTime.set(realTimeProvider.getCurrentTime());
|
|
||||||
this.accessDeviceRead = (timeProvider.getCurrentTime());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns time, in milliseconds, when this point was accessed on host side
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public synchronized long getHostReadTime() {
|
|
||||||
return accessHostRead;
|
|
||||||
};
|
|
||||||
|
|
||||||
public synchronized long getHostWriteTime() {
|
|
||||||
return accessHostWrite;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -302,7 +257,7 @@ public class AllocationPoint {
|
||||||
}
|
}
|
||||||
|
|
||||||
public synchronized void tickHostRead() {
|
public synchronized void tickHostRead() {
|
||||||
accessHostRead = (timeProvider.getCurrentTime());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostRead(ptrDataBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -310,17 +265,14 @@ public class AllocationPoint {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public synchronized void tickDeviceWrite() {
|
public synchronized void tickDeviceWrite() {
|
||||||
// deviceAccessTime.set(realTimeProvider.getCurrentTime());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceWrite(ptrDataBuffer);
|
||||||
tickDeviceRead();
|
|
||||||
accessDeviceWrite = (timeProvider.getCurrentTime());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method sets time when this point was changed on host
|
* This method sets time when this point was changed on host
|
||||||
*/
|
*/
|
||||||
public synchronized void tickHostWrite() {
|
public synchronized void tickHostWrite() {
|
||||||
tickHostRead();
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostWrite(ptrDataBuffer);
|
||||||
accessHostWrite = (timeProvider.getCurrentTime());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -329,10 +281,8 @@ public class AllocationPoint {
|
||||||
* @return true, if data is actual, false otherwise
|
* @return true, if data is actual, false otherwise
|
||||||
*/
|
*/
|
||||||
public synchronized boolean isActualOnHostSide() {
|
public synchronized boolean isActualOnHostSide() {
|
||||||
boolean result = accessHostWrite >= accessDeviceWrite
|
val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
|
||||||
|| accessHostRead >= accessDeviceWrite;
|
return s <= 0;
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -341,9 +291,8 @@ public class AllocationPoint {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public synchronized boolean isActualOnDeviceSide() {
|
public synchronized boolean isActualOnDeviceSide() {
|
||||||
boolean result = accessDeviceWrite >= accessHostWrite
|
val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer);
|
||||||
|| accessDeviceRead >= accessHostWrite;
|
return s >= 0;
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -355,6 +304,6 @@ public class AllocationPoint {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + ", shape=" + shape + '}';
|
return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + "}";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,12 +19,10 @@ package org.nd4j.jita.allocator.impl;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.jita.allocator.Allocator;
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
import org.nd4j.jita.allocator.enums.Aggressiveness;
|
import org.nd4j.jita.allocator.enums.Aggressiveness;
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
||||||
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
|
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.jita.allocator.pointers.PointersPair;
|
import org.nd4j.jita.allocator.pointers.PointersPair;
|
||||||
import org.nd4j.jita.allocator.time.Ring;
|
import org.nd4j.jita.allocator.time.Ring;
|
||||||
|
@ -37,29 +35,25 @@ import org.nd4j.jita.flow.FlowController;
|
||||||
import org.nd4j.jita.handler.MemoryHandler;
|
import org.nd4j.jita.handler.MemoryHandler;
|
||||||
import org.nd4j.jita.handler.impl.CudaZeroHandler;
|
import org.nd4j.jita.handler.impl.CudaZeroHandler;
|
||||||
import org.nd4j.jita.workspace.CudaWorkspace;
|
import org.nd4j.jita.workspace.CudaWorkspace;
|
||||||
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.MemoryKind;
|
import org.nd4j.linalg.api.memory.enums.MemoryKind;
|
||||||
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.cache.ConstantHandler;
|
import org.nd4j.linalg.cache.ConstantHandler;
|
||||||
import org.nd4j.linalg.compression.CompressedDataBuffer;
|
import org.nd4j.linalg.compression.CompressedDataBuffer;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
||||||
|
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
|
||||||
|
|
||||||
import java.lang.ref.ReferenceQueue;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
import java.util.concurrent.locks.LockSupport;
|
|
||||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -285,16 +279,10 @@ public class AtomicAllocator implements Allocator {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
|
public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
|
||||||
if (buffer instanceof Utf8Buffer)
|
|
||||||
return null;
|
|
||||||
|
|
||||||
return memoryHandler.getDevicePointer(buffer, context);
|
return memoryHandler.getDevicePointer(buffer, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Pointer getPointer(DataBuffer buffer) {
|
public Pointer getPointer(DataBuffer buffer) {
|
||||||
if (buffer instanceof Utf8Buffer)
|
|
||||||
return null;
|
|
||||||
|
|
||||||
return memoryHandler.getDevicePointer(buffer, getDeviceContext());
|
return memoryHandler.getDevicePointer(buffer, getDeviceContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,7 +308,7 @@ public class AtomicAllocator implements Allocator {
|
||||||
public Pointer getPointer(INDArray array, CudaContext context) {
|
public Pointer getPointer(INDArray array, CudaContext context) {
|
||||||
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
||||||
if (array.isEmpty() || array.isS())
|
if (array.isEmpty() || array.isS())
|
||||||
return null;
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
return memoryHandler.getDevicePointer(array.data(), context);
|
return memoryHandler.getDevicePointer(array.data(), context);
|
||||||
}
|
}
|
||||||
|
@ -372,20 +360,17 @@ public class AtomicAllocator implements Allocator {
|
||||||
@Override
|
@Override
|
||||||
public void synchronizeHostData(DataBuffer buffer) {
|
public void synchronizeHostData(DataBuffer buffer) {
|
||||||
// we don't want non-committed ops left behind
|
// we don't want non-committed ops left behind
|
||||||
//Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
// we don't synchronize constant buffers, since we assume they are always valid on host side
|
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||||
if (buffer.isConstant() || buffer.dataType() == DataType.UTF8 || AtomicAllocator.getInstance().getAllocationPoint(buffer).getPointers().getHostPointer() == null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// we actually need synchronization only in device-dependant environment. no-op otherwise
|
// we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code
|
||||||
if (memoryHandler.isDeviceDependant()) {
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||||
val point = getAllocationPoint(buffer.getTrackingPoint());
|
|
||||||
if (point == null)
|
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||||
throw new RuntimeException("AllocationPoint is NULL");
|
|
||||||
memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point);
|
//assert oPtr.address() == cPtr.address();
|
||||||
}
|
//assert buffer.address() == oPtr.address();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -446,6 +431,7 @@ public class AtomicAllocator implements Allocator {
|
||||||
|
|
||||||
|
|
||||||
public AllocationPoint pickExternalBuffer(DataBuffer buffer) {
|
public AllocationPoint pickExternalBuffer(DataBuffer buffer) {
|
||||||
|
/**
|
||||||
AllocationPoint point = new AllocationPoint();
|
AllocationPoint point = new AllocationPoint();
|
||||||
Long allocId = objectsTracker.getAndIncrement();
|
Long allocId = objectsTracker.getAndIncrement();
|
||||||
point.setObjectId(allocId);
|
point.setObjectId(allocId);
|
||||||
|
@ -458,6 +444,9 @@ public class AtomicAllocator implements Allocator {
|
||||||
point.tickHostRead();
|
point.tickHostRead();
|
||||||
|
|
||||||
return point;
|
return point;
|
||||||
|
*/
|
||||||
|
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -469,69 +458,8 @@ public class AtomicAllocator implements Allocator {
|
||||||
* @param location
|
* @param location
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location,
|
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
|
||||||
boolean initialize) {
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
AllocationPoint point = new AllocationPoint();
|
|
||||||
|
|
||||||
useTracker.set(System.currentTimeMillis());
|
|
||||||
|
|
||||||
// we use these longs as tracking codes for memory tracking
|
|
||||||
Long allocId = objectsTracker.getAndIncrement();
|
|
||||||
//point.attachBuffer(buffer);
|
|
||||||
point.setObjectId(allocId);
|
|
||||||
point.setShape(requiredMemory);
|
|
||||||
/*
|
|
||||||
if (buffer instanceof CudaIntDataBuffer) {
|
|
||||||
buffer.setConstant(true);
|
|
||||||
point.setConstant(true);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
/*int numBuckets = configuration.getNumberOfGcThreads();
|
|
||||||
int bucketId = RandomUtils.nextInt(0, numBuckets);
|
|
||||||
|
|
||||||
GarbageBufferReference reference =
|
|
||||||
new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);*/
|
|
||||||
//point.attachReference(reference);
|
|
||||||
point.setDeviceId(-1);
|
|
||||||
|
|
||||||
if (buffer.isAttached()) {
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(requiredMemory);
|
|
||||||
|
|
||||||
// workaround for init order
|
|
||||||
getMemoryHandler().getCudaContext();
|
|
||||||
point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
|
|
||||||
val workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();
|
|
||||||
|
|
||||||
val pair = new PointersPair();
|
|
||||||
val ptrDev = workspace.alloc(reqMem, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize);
|
|
||||||
|
|
||||||
if (ptrDev != null) {
|
|
||||||
pair.setDevicePointer(ptrDev);
|
|
||||||
point.setAllocationStatus(AllocationStatus.DEVICE);
|
|
||||||
} else {
|
|
||||||
// we allocate initial host pointer only
|
|
||||||
val ptrHost = workspace.alloc(reqMem, MemoryKind.HOST, requiredMemory.getDataType(), initialize);
|
|
||||||
pair.setHostPointer(ptrHost);
|
|
||||||
|
|
||||||
pair.setDevicePointer(ptrHost);
|
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
|
||||||
}
|
|
||||||
|
|
||||||
point.setAttached(true);
|
|
||||||
|
|
||||||
point.setPointers(pair);
|
|
||||||
} else {
|
|
||||||
// we stay naive on PointersPair, we just don't know on this level, which pointers are set. MemoryHandler will be used for that
|
|
||||||
PointersPair pair = memoryHandler.alloc(location, point, requiredMemory, initialize);
|
|
||||||
point.setPointers(pair);
|
|
||||||
}
|
|
||||||
|
|
||||||
allocationsMap.put(allocId, point);
|
|
||||||
//point.tickHostRead();
|
|
||||||
point.tickDeviceWrite();
|
|
||||||
//point.setAllocationStatus(location);
|
|
||||||
return point;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -619,10 +547,11 @@ public class AtomicAllocator implements Allocator {
|
||||||
*/
|
*/
|
||||||
if (point.getBuffer() == null) {
|
if (point.getBuffer() == null) {
|
||||||
purgeZeroObject(bucketId, object, point, false);
|
purgeZeroObject(bucketId, object, point, false);
|
||||||
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
|
//freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
elementsDropped.incrementAndGet();
|
//elementsDropped.incrementAndGet();
|
||||||
continue;
|
//continue;
|
||||||
} else {
|
} else {
|
||||||
elementsSurvived.incrementAndGet();
|
elementsSurvived.incrementAndGet();
|
||||||
}
|
}
|
||||||
|
@ -682,13 +611,14 @@ public class AtomicAllocator implements Allocator {
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
||||||
// we deallocate device memory
|
// we deallocate device memory
|
||||||
purgeDeviceObject(threadId, deviceId, object, point, false);
|
purgeDeviceObject(threadId, deviceId, object, point, false);
|
||||||
freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
|
//freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
|
||||||
|
|
||||||
// and we deallocate host memory, since object is dereferenced
|
// and we deallocate host memory, since object is dereferenced
|
||||||
purgeZeroObject(point.getBucketId(), object, point, false);
|
//purgeZeroObject(point.getBucketId(), object, point, false);
|
||||||
|
|
||||||
elementsDropped.incrementAndGet();
|
//elementsDropped.incrementAndGet();
|
||||||
continue;
|
//continue;
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
} ;
|
} ;
|
||||||
} else {
|
} else {
|
||||||
elementsSurvived.incrementAndGet();
|
elementsSurvived.incrementAndGet();
|
||||||
|
@ -1014,6 +944,31 @@ public class AtomicAllocator implements Allocator {
|
||||||
this.memoryHandler.memcpy(dstBuffer, srcBuffer);
|
this.memoryHandler.memcpy(dstBuffer, srcBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void tickHostWrite(DataBuffer buffer) {
|
||||||
|
getAllocationPoint(buffer).tickHostWrite();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void tickHostWrite(INDArray array) {
|
||||||
|
getAllocationPoint(array.data()).tickHostWrite();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void tickDeviceWrite(INDArray array) {
|
||||||
|
getAllocationPoint(array.data()).tickDeviceWrite();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AllocationPoint getAllocationPoint(INDArray array) {
|
||||||
|
return getAllocationPoint(array.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
|
||||||
|
return ((BaseCudaDataBuffer) buffer).getAllocationPoint();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns deviceId for current thread
|
* This method returns deviceId for current thread
|
||||||
* All values >= 0 are considered valid device IDs, all values < 0 are considered stubs.
|
* All values >= 0 are considered valid device IDs, all values < 0 are considered stubs.
|
||||||
|
@ -1031,48 +986,6 @@ public class AtomicAllocator implements Allocator {
|
||||||
return new CudaPointer(getDeviceId());
|
return new CudaPointer(getDeviceId());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tickHostWrite(DataBuffer buffer) {
|
|
||||||
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
|
|
||||||
point.tickHostWrite();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tickHostWrite(INDArray array) {
|
|
||||||
DataBuffer buffer =
|
|
||||||
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
|
||||||
|
|
||||||
tickHostWrite(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void tickDeviceWrite(INDArray array) {
|
|
||||||
DataBuffer buffer =
|
|
||||||
array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
|
||||||
AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
|
|
||||||
|
|
||||||
point.tickDeviceWrite();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public AllocationPoint getAllocationPoint(INDArray array) {
|
|
||||||
if (array.isEmpty())
|
|
||||||
return null;
|
|
||||||
|
|
||||||
DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
|
||||||
return getAllocationPoint(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
|
|
||||||
if (buffer instanceof CompressedDataBuffer) {
|
|
||||||
log.warn("Trying to get AllocationPoint from CompressedDataBuffer");
|
|
||||||
throw new RuntimeException("AP CDB");
|
|
||||||
}
|
|
||||||
|
|
||||||
return getAllocationPoint(buffer.getTrackingPoint());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
|
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
|
||||||
memoryHandler.registerAction(context, result, operands);
|
memoryHandler.registerAction(context, result, operands);
|
||||||
|
|
|
@ -23,46 +23,21 @@ import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import org.nd4j.linalg.api.memory.Deallocator;
|
import org.nd4j.linalg.api.memory.Deallocator;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueDataBuffer;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class CudaDeallocator implements Deallocator {
|
public class CudaDeallocator implements Deallocator {
|
||||||
|
|
||||||
private AllocationPoint point;
|
private OpaqueDataBuffer opaqueDataBuffer;
|
||||||
|
|
||||||
public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) {
|
public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) {
|
||||||
this.point = buffer.getAllocationPoint();
|
opaqueDataBuffer = buffer.getOpaqueDataBuffer();
|
||||||
if (this.point == null)
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deallocate() {
|
public void deallocate() {
|
||||||
log.trace("Deallocating CUDA memory");
|
log.trace("Deallocating CUDA memory");
|
||||||
// skipping any allocation that is coming from workspace
|
NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer);
|
||||||
if (point.isAttached() || point.isReleased()) {
|
|
||||||
// TODO: remove allocation point as well?
|
|
||||||
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
|
|
||||||
return;
|
|
||||||
|
|
||||||
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
|
|
||||||
|
|
||||||
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
|
|
||||||
AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
|
|
||||||
|
|
||||||
AtomicAllocator.getInstance().allocationsMap().remove(point.getObjectId());
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//log.info("Purging {} bytes...", AllocationUtils.getRequiredMemory(point.getShape()));
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.HOST) {
|
|
||||||
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
|
|
||||||
} else if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
|
||||||
AtomicAllocator.getInstance().purgeDeviceObject(0L, point.getDeviceId(), point.getObjectId(), point, false);
|
|
||||||
|
|
||||||
// and we deallocate host memory, since object is dereferenced
|
|
||||||
AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.jita.allocator.pointers.cuda;
|
package org.nd4j.jita.allocator.pointers.cuda;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.linalg.exception.ND4JException;
|
import org.nd4j.linalg.exception.ND4JException;
|
||||||
|
@ -37,8 +38,9 @@ public class cudaStream_t extends CudaPointer {
|
||||||
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
int res = nativeOps.streamSynchronize(this);
|
int res = nativeOps.streamSynchronize(this);
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
val ec = nativeOps.lastErrorCode();
|
||||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
if (ec != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec);
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,7 +129,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
|
||||||
|
|
||||||
AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
|
AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
|
||||||
|
|
||||||
long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape());
|
long requiredMemoryBytes = point.getNumberOfBytes();
|
||||||
val originalBytes = requiredMemoryBytes;
|
val originalBytes = requiredMemoryBytes;
|
||||||
requiredMemoryBytes += 8 - (requiredMemoryBytes % 8);
|
requiredMemoryBytes += 8 - (requiredMemoryBytes % 8);
|
||||||
|
|
||||||
|
@ -147,13 +147,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
|
||||||
if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
|
if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
|
||||||
if (point.getAllocationStatus() == AllocationStatus.HOST
|
if (point.getAllocationStatus() == AllocationStatus.HOST
|
||||||
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
|
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
|
||||||
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(),
|
//AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
|
||||||
false);
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
}
|
}
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
}
|
}
|
||||||
flowController.commitTransfer(context.getSpecialStream());
|
flowController.commitTransfer(context.getSpecialStream());
|
||||||
|
@ -176,14 +176,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
|
||||||
if (currentOffset >= MAX_CONSTANT_LENGTH) {
|
if (currentOffset >= MAX_CONSTANT_LENGTH) {
|
||||||
if (point.getAllocationStatus() == AllocationStatus.HOST
|
if (point.getAllocationStatus() == AllocationStatus.HOST
|
||||||
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
|
&& CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
|
||||||
AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(),
|
//AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
|
||||||
false);
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
}
|
}
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) {
|
||||||
originalBytes, 1, context.getSpecialStream()) == 0) {
|
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
}
|
}
|
||||||
flowController.commitTransfer(context.getSpecialStream());
|
flowController.commitTransfer(context.getSpecialStream());
|
||||||
|
@ -202,8 +201,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), originalBytes, 1,
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream());
|
||||||
context.getSpecialStream());
|
|
||||||
flowController.commitTransfer(context.getSpecialStream());
|
flowController.commitTransfer(context.getSpecialStream());
|
||||||
|
|
||||||
long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;
|
long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;
|
||||||
|
@ -212,7 +210,10 @@ public class ProtectedCudaConstantHandler implements ConstantHandler {
|
||||||
// logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);
|
// logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.CONSTANT);
|
point.setAllocationStatus(AllocationStatus.CONSTANT);
|
||||||
point.getPointers().setDevicePointer(new CudaPointer(cAddr));
|
//point.setDevicePointer(new CudaPointer(cAddr));
|
||||||
|
if (1 > 0)
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
point.setConstant(true);
|
point.setConstant(true);
|
||||||
point.tickDeviceWrite();
|
point.tickDeviceWrite();
|
||||||
point.setDeviceId(deviceId);
|
point.setDeviceId(deviceId);
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.jita.conf.Configuration;
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
import org.nd4j.jita.flow.FlowController;
|
import org.nd4j.jita.flow.FlowController;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
||||||
|
@ -70,53 +71,12 @@ public class SynchronousFlowController implements FlowController {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void synchronizeToHost(AllocationPoint point) {
|
public void synchronizeToHost(AllocationPoint point) {
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer());
|
||||||
if (!point.isActualOnHostSide()) {
|
|
||||||
val context = allocator.getDeviceContext();
|
|
||||||
|
|
||||||
if (!point.isConstant())
|
|
||||||
waitTillFinished(point);
|
|
||||||
|
|
||||||
// if this piece of memory is device-dependant, we'll also issue copyback once
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
|
|
||||||
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
|
||||||
val bytes = AllocationUtils.getRequiredMemory(point.getShape());
|
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0)
|
|
||||||
throw new IllegalStateException("synchronizeToHost memcpyAsync failed: " + point.getShape());
|
|
||||||
|
|
||||||
commitTransfer(context.getSpecialStream());
|
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.DEVICE_TO_HOST);
|
|
||||||
}
|
|
||||||
|
|
||||||
// updating host read timer
|
|
||||||
point.tickHostRead();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void synchronizeToDevice(@NonNull AllocationPoint point) {
|
public void synchronizeToDevice(@NonNull AllocationPoint point) {
|
||||||
if (point.isConstant())
|
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer());
|
||||||
return;
|
|
||||||
|
|
||||||
if (!point.isActualOnDeviceSide()) {
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
|
||||||
val context = allocator.getDeviceContext();
|
|
||||||
|
|
||||||
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(),
|
|
||||||
AllocationUtils.getRequiredMemory(point.getShape()),
|
|
||||||
CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0)
|
|
||||||
throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
|
|
||||||
|
|
||||||
commitTransfer(context.getSpecialStream());
|
|
||||||
point.tickDeviceRead();
|
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -147,7 +107,6 @@ public class SynchronousFlowController implements FlowController {
|
||||||
val pointData = allocator.getAllocationPoint(operand);
|
val pointData = allocator.getAllocationPoint(operand);
|
||||||
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
|
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
|
||||||
|
|
||||||
pointData.acquireLock();
|
|
||||||
|
|
||||||
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) {
|
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) {
|
||||||
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
|
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
|
||||||
|
@ -172,15 +131,12 @@ public class SynchronousFlowController implements FlowController {
|
||||||
val cId = allocator.getDeviceId();
|
val cId = allocator.getDeviceId();
|
||||||
|
|
||||||
|
|
||||||
if (result != null && !result.isEmpty() && !result.isS()) {
|
if (result != null && !result.isEmpty()) {
|
||||||
Nd4j.getCompressor().autoDecompress(result);
|
Nd4j.getCompressor().autoDecompress(result);
|
||||||
prepareDelayedMemory(result);
|
prepareDelayedMemory(result);
|
||||||
val pointData = allocator.getAllocationPoint(result);
|
val pointData = allocator.getAllocationPoint(result);
|
||||||
val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
|
val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
|
||||||
|
|
||||||
pointData.acquireLock();
|
|
||||||
|
|
||||||
|
|
||||||
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
|
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
|
||||||
DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data()
|
DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data()
|
||||||
: result.data().originalDataBuffer();
|
: result.data().originalDataBuffer();
|
||||||
|
@ -206,8 +162,7 @@ public class SynchronousFlowController implements FlowController {
|
||||||
|
|
||||||
val pointData = allocator.getAllocationPoint(operand);
|
val pointData = allocator.getAllocationPoint(operand);
|
||||||
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
|
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE);
|
||||||
pointData.acquireLock();
|
|
||||||
|
|
||||||
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
|
if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
|
||||||
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
|
DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data()
|
||||||
|
@ -240,14 +195,12 @@ public class SynchronousFlowController implements FlowController {
|
||||||
eventsProvider.storeEvent(result.getLastWriteEvent());
|
eventsProvider.storeEvent(result.getLastWriteEvent());
|
||||||
result.setLastWriteEvent(eventsProvider.getEvent());
|
result.setLastWriteEvent(eventsProvider.getEvent());
|
||||||
result.getLastWriteEvent().register(context.getOldStream());
|
result.getLastWriteEvent().register(context.getOldStream());
|
||||||
result.releaseLock();
|
|
||||||
|
|
||||||
|
|
||||||
for (AllocationPoint operand : operands) {
|
for (AllocationPoint operand : operands) {
|
||||||
eventsProvider.storeEvent(operand.getLastReadEvent());
|
eventsProvider.storeEvent(operand.getLastReadEvent());
|
||||||
operand.setLastReadEvent(eventsProvider.getEvent());
|
operand.setLastReadEvent(eventsProvider.getEvent());
|
||||||
operand.getLastReadEvent().register(context.getOldStream());
|
operand.getLastReadEvent().register(context.getOldStream());
|
||||||
operand.releaseLock();
|
|
||||||
}
|
}
|
||||||
// context.syncOldStream();
|
// context.syncOldStream();
|
||||||
}
|
}
|
||||||
|
@ -263,7 +216,6 @@ public class SynchronousFlowController implements FlowController {
|
||||||
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
|
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
|
||||||
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
|
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
|
||||||
pointOperand.getLastWriteEvent().register(context.getOldStream());
|
pointOperand.getLastWriteEvent().register(context.getOldStream());
|
||||||
pointOperand.releaseLock();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,14 +228,12 @@ public class SynchronousFlowController implements FlowController {
|
||||||
eventsProvider.storeEvent(point.getLastWriteEvent());
|
eventsProvider.storeEvent(point.getLastWriteEvent());
|
||||||
point.setLastWriteEvent(eventsProvider.getEvent());
|
point.setLastWriteEvent(eventsProvider.getEvent());
|
||||||
point.getLastWriteEvent().register(context.getOldStream());
|
point.getLastWriteEvent().register(context.getOldStream());
|
||||||
point.releaseLock();
|
|
||||||
|
|
||||||
for (INDArray operand : operands) {
|
for (INDArray operand : operands) {
|
||||||
if (operand == null || operand.isEmpty())
|
if (operand == null || operand.isEmpty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
val pointOperand = allocator.getAllocationPoint(operand);
|
val pointOperand = allocator.getAllocationPoint(operand);
|
||||||
pointOperand.releaseLock();
|
|
||||||
eventsProvider.storeEvent(pointOperand.getLastReadEvent());
|
eventsProvider.storeEvent(pointOperand.getLastReadEvent());
|
||||||
pointOperand.setLastReadEvent(eventsProvider.getEvent());
|
pointOperand.setLastReadEvent(eventsProvider.getEvent());
|
||||||
pointOperand.getLastReadEvent().register(context.getOldStream());
|
pointOperand.getLastReadEvent().register(context.getOldStream());
|
||||||
|
@ -295,7 +245,6 @@ public class SynchronousFlowController implements FlowController {
|
||||||
val context = allocator.getDeviceContext();
|
val context = allocator.getDeviceContext();
|
||||||
|
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
result.acquireLock();
|
|
||||||
result.setCurrentContext(context);
|
result.setCurrentContext(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,7 +252,6 @@ public class SynchronousFlowController implements FlowController {
|
||||||
if (operand == null)
|
if (operand == null)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
operand.acquireLock();
|
|
||||||
operand.setCurrentContext(context);
|
operand.setCurrentContext(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.jita.handler.impl;
|
package org.nd4j.jita.handler.impl;
|
||||||
|
|
||||||
|
import lombok.var;
|
||||||
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||||
import org.nd4j.shade.guava.collect.Table;
|
import org.nd4j.shade.guava.collect.Table;
|
||||||
|
@ -44,9 +45,6 @@ import org.nd4j.jita.flow.FlowController;
|
||||||
import org.nd4j.jita.flow.impl.GridFlowController;
|
import org.nd4j.jita.flow.impl.GridFlowController;
|
||||||
import org.nd4j.jita.handler.MemoryHandler;
|
import org.nd4j.jita.handler.MemoryHandler;
|
||||||
import org.nd4j.jita.memory.MemoryProvider;
|
import org.nd4j.jita.memory.MemoryProvider;
|
||||||
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
|
|
||||||
import org.nd4j.jita.memory.impl.CudaDirectProvider;
|
|
||||||
import org.nd4j.jita.memory.impl.CudaFullCachingProvider;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -99,9 +97,6 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
|
|
||||||
private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
|
private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
|
||||||
|
|
||||||
@Getter
|
|
||||||
private final MemoryProvider memoryProvider;
|
|
||||||
|
|
||||||
private final FlowController flowController;
|
private final FlowController flowController;
|
||||||
|
|
||||||
private final AllocationStatus INITIAL_LOCATION;
|
private final AllocationStatus INITIAL_LOCATION;
|
||||||
|
@ -148,20 +143,6 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
|
throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (configuration.getAllocationModel()) {
|
|
||||||
case CACHE_ALL:
|
|
||||||
this.memoryProvider = new CudaFullCachingProvider();
|
|
||||||
break;
|
|
||||||
case CACHE_HOST:
|
|
||||||
this.memoryProvider = new CudaCachingZeroProvider();
|
|
||||||
break;
|
|
||||||
case DIRECT:
|
|
||||||
this.memoryProvider = new CudaDirectProvider();
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]");
|
|
||||||
}
|
|
||||||
|
|
||||||
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
|
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
|
||||||
for (int i = 0; i < numDevices; i++) {
|
for (int i = 0; i < numDevices; i++) {
|
||||||
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
|
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
|
||||||
|
@ -191,7 +172,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
int numBuckets = configuration.getNumberOfGcThreads();
|
int numBuckets = configuration.getNumberOfGcThreads();
|
||||||
long bucketId = RandomUtils.nextInt(0, numBuckets);
|
long bucketId = RandomUtils.nextInt(0, numBuckets);
|
||||||
|
|
||||||
long reqMemory = AllocationUtils.getRequiredMemory(point.getShape());
|
long reqMemory = point.getNumberOfBytes();
|
||||||
|
|
||||||
zeroUseCounter.addAndGet(reqMemory);
|
zeroUseCounter.addAndGet(reqMemory);
|
||||||
|
|
||||||
|
@ -221,130 +202,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape,
|
public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape,
|
||||||
boolean initialize) {
|
boolean initialize) {
|
||||||
|
|
||||||
long reqMemory = AllocationUtils.getRequiredMemory(shape);
|
throw new UnsupportedOperationException();
|
||||||
val context = getCudaContext();
|
|
||||||
switch (targetMode) {
|
|
||||||
case HOST: {
|
|
||||||
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
|
|
||||||
|
|
||||||
while (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
|
|
||||||
|
|
||||||
val before = MemoryTracker.getInstance().getActiveHostAmount();
|
|
||||||
memoryProvider.purgeCache();
|
|
||||||
Nd4j.getMemoryManager().invokeGc();
|
|
||||||
val after = MemoryTracker.getInstance().getActiveHostAmount();
|
|
||||||
|
|
||||||
log.debug("[HOST] before: {}; after: {};", before, after);
|
|
||||||
|
|
||||||
if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) {
|
|
||||||
try {
|
|
||||||
log.warn("No available [HOST] memory, sleeping for a while... Consider increasing -Xmx next time.");
|
|
||||||
log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: [" + zeroAllocations.get(0) + "]");
|
|
||||||
|
|
||||||
memoryProvider.purgeCache();
|
|
||||||
Nd4j.getMemoryManager().invokeGc();
|
|
||||||
Thread.sleep(1000);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PointersPair pair = memoryProvider.malloc(shape, point, targetMode);
|
|
||||||
|
|
||||||
if (initialize) {
|
|
||||||
org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory);
|
|
||||||
point.tickHostWrite();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pickupHostAllocation(point);
|
|
||||||
|
|
||||||
return pair;
|
|
||||||
}
|
|
||||||
case DEVICE: {
|
|
||||||
int deviceId = getDeviceId();
|
|
||||||
|
|
||||||
PointersPair returnPair = new PointersPair();
|
|
||||||
PointersPair tmpPair = new PointersPair();
|
|
||||||
|
|
||||||
if (point.getPointers() == null)
|
|
||||||
point.setPointers(tmpPair);
|
|
||||||
|
|
||||||
if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId, reqMemory)) {
|
|
||||||
point.setDeviceId(deviceId);
|
|
||||||
val pair = memoryProvider.malloc(shape, point, targetMode);
|
|
||||||
if (pair != null) {
|
|
||||||
returnPair.setDevicePointer(pair.getDevicePointer());
|
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.DEVICE);
|
|
||||||
|
|
||||||
if (point.getPointers() == null)
|
|
||||||
throw new RuntimeException("PointersPair can't be null");
|
|
||||||
|
|
||||||
point.getPointers().setDevicePointer(pair.getDevicePointer());
|
|
||||||
|
|
||||||
deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId());
|
|
||||||
|
|
||||||
|
|
||||||
val p = point.getBucketId();
|
|
||||||
|
|
||||||
if (p != null) {
|
|
||||||
val m = zeroAllocations.get(point.getBucketId());
|
|
||||||
|
|
||||||
// m can be null, if that's point from workspace - just no bucketId for it
|
|
||||||
if (m != null)
|
|
||||||
m.remove(point.getObjectId());
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory);
|
|
||||||
|
|
||||||
if (!initialize) {
|
|
||||||
point.tickDeviceWrite();
|
|
||||||
} else {
|
|
||||||
nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0, context.getSpecialStream());
|
|
||||||
context.getSpecialStream().synchronize();
|
|
||||||
|
|
||||||
point.tickDeviceWrite();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]; Approximate free bytes: {}; Real free bytes: {}", deviceId, reqMemory, MemoryTracker.getInstance().getApproximateFreeMemory(deviceId), MemoryTracker.getInstance().getPreciseFreeMemory(deviceId));
|
|
||||||
log.info("Total allocated dev_0: {}", MemoryTracker.getInstance().getActiveMemory(0));
|
|
||||||
log.info("Cached dev_0: {}", MemoryTracker.getInstance().getCachedAmount(0));
|
|
||||||
log.info("Allocated dev_0: {}", MemoryTracker.getInstance().getAllocatedAmount(0));
|
|
||||||
log.info("Workspace dev_0: {}", MemoryTracker.getInstance().getWorkspaceAllocatedAmount(0));
|
|
||||||
//log.info("Total allocated dev_1: {}", MemoryTracker.getInstance().getActiveMemory(1));
|
|
||||||
// if device memory allocation failed (aka returned NULL), keep using host memory instead
|
|
||||||
|
|
||||||
returnPair.setDevicePointer(tmpPair.getHostPointer());
|
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
|
||||||
|
|
||||||
Nd4j.getMemoryManager().invokeGc();
|
|
||||||
try {
|
|
||||||
Thread.sleep(100);
|
|
||||||
} catch (Exception e) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]",
|
|
||||||
deviceId);
|
|
||||||
|
|
||||||
Nd4j.getMemoryManager().invokeGc();
|
|
||||||
try {
|
|
||||||
Thread.sleep(100);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return returnPair;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -356,7 +214,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
|
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
|
||||||
return memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory);
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -371,47 +229,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
@Override
|
@Override
|
||||||
public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point,
|
public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point,
|
||||||
AllocationShape shape, CudaContext context) {
|
AllocationShape shape, CudaContext context) {
|
||||||
//log.info("RELOCATE CALLED: [" +currentStatus+ "] -> ["+targetStatus+"]");
|
|
||||||
|
|
||||||
if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) {
|
|
||||||
// DEVICE -> HOST
|
|
||||||
DataBuffer targetBuffer = point.getBuffer();
|
|
||||||
if (targetBuffer == null)
|
|
||||||
throw new IllegalStateException("Target buffer is NULL!");
|
|
||||||
|
|
||||||
Pointer devicePointer = new CudaPointer(point.getPointers().getDevicePointer().address());
|
|
||||||
|
|
||||||
} else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) {
|
|
||||||
// HOST -> DEVICE
|
|
||||||
|
|
||||||
|
|
||||||
// TODO: this probably should be removed
|
|
||||||
if (point.isConstant()) {
|
|
||||||
//log.info("Skipping relocation for constant");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (point.getPointers().getDevicePointer() == null) {
|
|
||||||
throw new IllegalStateException("devicePointer is NULL!");
|
|
||||||
}
|
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
|
|
||||||
AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice,
|
|
||||||
context.getSpecialStream()) == 0)
|
|
||||||
throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address()
|
|
||||||
+ "] -> [" + point.getDevicePointer().address() + "]");
|
|
||||||
|
|
||||||
flowController.commitTransfer(context.getSpecialStream());
|
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
|
||||||
|
|
||||||
//context.syncOldStream();
|
|
||||||
|
|
||||||
} else
|
|
||||||
throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus
|
|
||||||
+ "] -> [" + targetStatus + "]");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -440,11 +258,6 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
@Override
|
@Override
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public void copyforward(AllocationPoint point, AllocationShape shape) {
|
public void copyforward(AllocationPoint point, AllocationShape shape) {
|
||||||
/*
|
|
||||||
Technically that's just a case for relocate, with source as HOST and target point.getAllocationStatus()
|
|
||||||
*/
|
|
||||||
log.info("copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape());
|
|
||||||
//relocate(AllocationStatus.HOST, point.getAllocationStatus(), point, shape);
|
|
||||||
throw new UnsupportedOperationException("Deprecated call");
|
throw new UnsupportedOperationException("Deprecated call");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -467,15 +280,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void free(AllocationPoint point, AllocationStatus target) {
|
public void free(AllocationPoint point, AllocationStatus target) {
|
||||||
//if (point.getAllocationStatus() == AllocationStatus.DEVICE)
|
|
||||||
//deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId());
|
|
||||||
|
|
||||||
//zeroAllocations.get(point.getBucketId()).remove(point.getObjectId());
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE)
|
|
||||||
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(),
|
|
||||||
AllocationUtils.getRequiredMemory(point.getShape()));
|
|
||||||
|
|
||||||
memoryProvider.free(point);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -525,7 +330,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
CudaContext tContext = null;
|
CudaContext tContext = null;
|
||||||
|
|
||||||
if (dstBuffer.isConstant()) {
|
if (dstBuffer.isConstant()) {
|
||||||
org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L);
|
org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L);
|
||||||
org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length);
|
org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length);
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
@ -534,14 +339,34 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
|
|
||||||
point.tickHostRead();
|
point.tickHostRead();
|
||||||
} else {
|
} else {
|
||||||
|
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
|
||||||
|
Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
|
||||||
|
|
||||||
|
if (tContext == null)
|
||||||
|
tContext = flowController.prepareAction(point);
|
||||||
|
|
||||||
|
var prof = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
|
flowController.commitTransfer(tContext.getSpecialStream());
|
||||||
|
|
||||||
|
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
|
||||||
|
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
|
||||||
|
|
||||||
|
flowController.commitTransfer(tContext.getSpecialStream());
|
||||||
|
|
||||||
|
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
||||||
|
|
||||||
|
flowController.registerAction(tContext, point);
|
||||||
|
point.tickDeviceWrite();
|
||||||
|
|
||||||
// we optionally copy to host memory
|
// we optionally copy to host memory
|
||||||
if (point.getPointers().getHostPointer() != null) {
|
if (point.getHostPointer() != null) {
|
||||||
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
|
Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
|
||||||
|
|
||||||
CudaContext context = flowController.prepareAction(point);
|
CudaContext context = flowController.prepareAction(point);
|
||||||
tContext = context;
|
tContext = context;
|
||||||
|
|
||||||
val prof = PerformanceTracker.getInstance().helperStartTransaction();
|
prof = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0)
|
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0)
|
||||||
throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]");
|
throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]");
|
||||||
|
@ -552,28 +377,10 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.HOST)
|
if (point.getAllocationStatus() == AllocationStatus.HOST)
|
||||||
flowController.registerAction(context, point);
|
flowController.registerAction(context, point);
|
||||||
|
|
||||||
|
point.tickHostRead();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
|
||||||
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
|
|
||||||
|
|
||||||
if (tContext == null)
|
|
||||||
tContext = flowController.prepareAction(point);
|
|
||||||
|
|
||||||
val prof = PerformanceTracker.getInstance().helperStartTransaction();
|
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0)
|
|
||||||
throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]");
|
|
||||||
|
|
||||||
flowController.commitTransfer(tContext.getSpecialStream());
|
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE);
|
|
||||||
|
|
||||||
flowController.registerAction(tContext, point);
|
|
||||||
point.tickDeviceWrite();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -581,7 +388,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
CudaContext context) {
|
CudaContext context) {
|
||||||
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
||||||
|
|
||||||
Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset);
|
Pointer dP = new CudaPointer((point.getDevicePointer().address()) + dstOffset);
|
||||||
|
|
||||||
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
|
@ -604,7 +411,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
CudaContext context = getCudaContext();
|
CudaContext context = getCudaContext();
|
||||||
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
|
||||||
|
|
||||||
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
|
Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset);
|
||||||
|
|
||||||
val profH = PerformanceTracker.getInstance().helperStartTransaction();
|
val profH = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
|
@ -614,7 +421,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);
|
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);
|
||||||
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
||||||
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
|
Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset);
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
|
@ -717,23 +524,22 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
@Override
|
@Override
|
||||||
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
|
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
|
||||||
// TODO: It would be awesome to get rid of typecasting here
|
// TODO: It would be awesome to get rid of typecasting here
|
||||||
//getCudaContext().syncOldStream();
|
|
||||||
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
|
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
|
||||||
|
|
||||||
// if that's device state, we probably might want to update device memory state
|
// if that's device state, we probably might want to update device memory state
|
||||||
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
|
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
|
||||||
if (!dstPoint.isActualOnDeviceSide()) {
|
if (!dstPoint.isActualOnDeviceSide()) {
|
||||||
// log.info("Relocating to GPU");
|
//relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
|
||||||
relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// we update memory use counter, to announce that it's somehow used on device
|
if (dstPoint.getDevicePointer() == null)
|
||||||
dstPoint.tickDeviceRead();
|
return null;
|
||||||
|
|
||||||
// return pointer with offset if needed. length is specified for constructor compatibility purposes
|
|
||||||
val p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(),
|
// return pointer. length is specified for constructor compatibility purposes. Offset is accounted at C++ side
|
||||||
(buffer.offset() * buffer.getElementSize()));
|
val p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0);
|
||||||
|
|
||||||
if (OpProfiler.getInstance().getConfig().isCheckLocality())
|
if (OpProfiler.getInstance().getConfig().isCheckLocality())
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1);
|
||||||
|
@ -749,10 +555,17 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
case SHORT:
|
case SHORT:
|
||||||
case UINT16:
|
case UINT16:
|
||||||
case HALF:
|
case HALF:
|
||||||
|
case BFLOAT16:
|
||||||
return p.asShortPointer();
|
return p.asShortPointer();
|
||||||
case UINT64:
|
case UINT64:
|
||||||
case LONG:
|
case LONG:
|
||||||
return p.asLongPointer();
|
return p.asLongPointer();
|
||||||
|
case UTF8:
|
||||||
|
case UBYTE:
|
||||||
|
case BYTE:
|
||||||
|
return p.asBytePointer();
|
||||||
|
case BOOL:
|
||||||
|
return p.asBooleanPointer();
|
||||||
default:
|
default:
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
@ -769,17 +582,14 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
|
AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
|
||||||
|
|
||||||
// return pointer with offset if needed. length is specified for constructor compatibility purposes
|
// return pointer with offset if needed. length is specified for constructor compatibility purposes
|
||||||
if (dstPoint.getPointers().getHostPointer() == null) {
|
if (dstPoint.getHostPointer() == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
//dstPoint.tickHostWrite();
|
|
||||||
//dstPoint.tickHostRead();
|
|
||||||
//log.info("Requesting host pointer for {}", buffer);
|
|
||||||
//getCudaContext().syncOldStream();
|
|
||||||
synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
|
synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
|
||||||
|
|
||||||
CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(),
|
CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0);
|
||||||
(buffer.offset() * buffer.getElementSize()));
|
|
||||||
switch (buffer.dataType()) {
|
switch (buffer.dataType()) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return p.asDoublePointer();
|
return p.asDoublePointer();
|
||||||
|
@ -805,6 +615,9 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
public synchronized void relocateObject(DataBuffer buffer) {
|
public synchronized void relocateObject(DataBuffer buffer) {
|
||||||
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
|
|
||||||
|
if (1 > 0)
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
// we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
|
// we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
|
||||||
if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
|
if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
|
||||||
return;
|
return;
|
||||||
|
@ -838,14 +651,14 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
// if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
|
// if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
|
||||||
// host part is optional
|
// host part is optional
|
||||||
if (dstPoint.getHostPointer() != null) {
|
if (dstPoint.getHostPointer() != null) {
|
||||||
val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
|
//val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
|
||||||
dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
|
//dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
|
||||||
}
|
}
|
||||||
|
|
||||||
val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
|
//val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
|
||||||
dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
|
//dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
|
||||||
|
|
||||||
//log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
|
////log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
|
||||||
|
|
||||||
CudaContext context = getCudaContext();
|
CudaContext context = getCudaContext();
|
||||||
|
|
||||||
|
@ -876,10 +689,10 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
|
|
||||||
Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
|
Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
|
||||||
|
|
||||||
dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
|
//dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
|
||||||
|
|
||||||
if (dstPoint.getHostPointer() != null) {
|
if (dstPoint.getHostPointer() != null) {
|
||||||
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
|
// dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
|
||||||
}
|
}
|
||||||
|
|
||||||
dstPoint.setDeviceId(deviceId);
|
dstPoint.setDeviceId(deviceId);
|
||||||
|
@ -908,11 +721,10 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
context.syncSpecialStream();
|
context.syncSpecialStream();
|
||||||
}
|
}
|
||||||
|
|
||||||
memoryProvider.free(dstPoint);
|
//deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
|
||||||
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
|
|
||||||
|
|
||||||
// we replace original device pointer with new one
|
// we replace original device pointer with new one
|
||||||
alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
|
//alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
|
||||||
|
|
||||||
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
val profD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
|
@ -940,6 +752,9 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
public boolean promoteObject(DataBuffer buffer) {
|
public boolean promoteObject(DataBuffer buffer) {
|
||||||
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
|
|
||||||
|
if (1 > 0)
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
if (dstPoint.getAllocationStatus() != AllocationStatus.HOST)
|
if (dstPoint.getAllocationStatus() != AllocationStatus.HOST)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@ -952,20 +767,19 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
Nd4j.getConstantHandler().moveToConstantSpace(buffer);
|
Nd4j.getConstantHandler().moveToConstantSpace(buffer);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
|
PointersPair pair = null; //memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
|
||||||
|
|
||||||
if (pair != null) {
|
if (pair != null) {
|
||||||
Integer deviceId = getDeviceId();
|
Integer deviceId = getDeviceId();
|
||||||
// log.info("Promoting object to device: [{}]", deviceId);
|
// log.info("Promoting object to device: [{}]", deviceId);
|
||||||
|
|
||||||
dstPoint.getPointers().setDevicePointer(pair.getDevicePointer());
|
//dstPoint.setDevicePointer(pair.getDevicePointer());
|
||||||
dstPoint.setAllocationStatus(AllocationStatus.DEVICE);
|
dstPoint.setAllocationStatus(AllocationStatus.DEVICE);
|
||||||
|
|
||||||
deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId());
|
deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId());
|
||||||
|
|
||||||
zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId());
|
zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId());
|
||||||
deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId,
|
//deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, AllocationUtils.getRequiredMemory(dstPoint.getShape()));
|
||||||
AllocationUtils.getRequiredMemory(dstPoint.getShape()));
|
|
||||||
|
|
||||||
|
|
||||||
dstPoint.tickHostWrite();
|
dstPoint.tickHostWrite();
|
||||||
|
@ -1103,7 +917,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
if (deviceAllocations.get(deviceId).containsKey(objectId))
|
if (deviceAllocations.get(deviceId).containsKey(objectId))
|
||||||
throw new IllegalStateException("Can't happen ever");
|
throw new IllegalStateException("Can't happen ever");
|
||||||
|
|
||||||
deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
|
//deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
point.setAllocationStatus(AllocationStatus.HOST);
|
||||||
|
|
||||||
|
@ -1119,6 +933,9 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
|
public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
|
||||||
|
if (1 > 0)
|
||||||
|
throw new UnsupportedOperationException("Pew-pew");
|
||||||
|
|
||||||
forget(point, AllocationStatus.HOST);
|
forget(point, AllocationStatus.HOST);
|
||||||
|
|
||||||
flowController.waitTillReleased(point);
|
flowController.waitTillReleased(point);
|
||||||
|
@ -1127,8 +944,8 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
if (point.getHostPointer() != null) {
|
if (point.getHostPointer() != null) {
|
||||||
free(point, AllocationStatus.HOST);
|
free(point, AllocationStatus.HOST);
|
||||||
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1;
|
//long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1;
|
||||||
zeroUseCounter.addAndGet(reqMem);
|
//zeroUseCounter.addAndGet(reqMem);
|
||||||
}
|
}
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.DEALLOCATED);
|
point.setAllocationStatus(AllocationStatus.DEALLOCATED);
|
||||||
|
@ -1252,4 +1069,9 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
public FlowController getFlowController() {
|
public FlowController getFlowController() {
|
||||||
return flowController;
|
return flowController;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MemoryProvider getMemoryProvider() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -147,7 +147,7 @@ public class CudaMemoryManager extends BasicMemoryManager {
|
||||||
// Nd4j.getShapeInfoProvider().purgeCache();
|
// Nd4j.getShapeInfoProvider().purgeCache();
|
||||||
|
|
||||||
// purge memory cache
|
// purge memory cache
|
||||||
AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
|
//AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,303 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.jita.memory.impl;
|
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationShape;
|
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
|
||||||
import org.nd4j.jita.allocator.pointers.PointersPair;
|
|
||||||
import org.nd4j.jita.allocator.utils.AllocationUtils;
|
|
||||||
import org.nd4j.jita.conf.Configuration;
|
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
|
||||||
import org.nd4j.jita.memory.MemoryProvider;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Queue;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
|
||||||
import java.util.concurrent.Semaphore;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
import org.nd4j.jita.allocator.impl.MemoryTracker;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This is MemoryProvider implementation, that adds cache for memory reuse purposes. Only host memory is cached for future reuse.
|
|
||||||
*
|
|
||||||
* If some memory chunk gets released via allocator, it'll be probably saved for future reused within same JVM process.
|
|
||||||
*
|
|
||||||
* @author raver119@gmail.com
|
|
||||||
*/
|
|
||||||
public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider {
|
|
||||||
private static Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);
|
|
||||||
|
|
||||||
protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
protected final AtomicLong cacheZeroHit = new AtomicLong(0);
|
|
||||||
protected final AtomicLong cacheZeroMiss = new AtomicLong(0);
|
|
||||||
|
|
||||||
protected final AtomicLong cacheDeviceHit = new AtomicLong(0);
|
|
||||||
protected final AtomicLong cacheDeviceMiss = new AtomicLong(0);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private final AtomicLong allocRequests = new AtomicLong(0);
|
|
||||||
|
|
||||||
protected final AtomicLong zeroCachedAmount = new AtomicLong(0);
|
|
||||||
protected List<AtomicLong> deviceCachedAmount = new ArrayList<>();
|
|
||||||
|
|
||||||
|
|
||||||
protected final Semaphore singleLock = new Semaphore(1);
|
|
||||||
|
|
||||||
// we don't cache allocations greater then this value
|
|
||||||
//protected final long MAX_SINGLE_ALLOCATION = configuration.getMaximumHostCacheableLength();
|
|
||||||
|
|
||||||
// maximum cached size of memory
|
|
||||||
//protected final long MAX_CACHED_MEMORY = configuration.getMaximumHostCache();
|
|
||||||
|
|
||||||
// memory chunks below this threshold will be guaranteed regardless of number of cache entries
|
|
||||||
// that especially covers all possible variations of shapeInfoDataBuffers in all possible cases
|
|
||||||
protected final long FORCED_CACHE_THRESHOLD = 96;
|
|
||||||
|
|
||||||
// number of preallocation entries for each yet-unknown shape
|
|
||||||
//protected final int PREALLOCATION_LIMIT = configuration.getPreallocationCalls();
|
|
||||||
|
|
||||||
public CudaCachingZeroProvider() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method provides PointersPair to memory chunk specified by AllocationShape
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape.
|
|
||||||
*
|
|
||||||
* @param shape shape of desired memory chunk
|
|
||||||
* @param point target AllocationPoint structure
|
|
||||||
* @param location either HOST or DEVICE
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
|
|
||||||
long reqMemory = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
|
|
||||||
if (location == AllocationStatus.HOST && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) {
|
|
||||||
|
|
||||||
val cache = zeroCache.get(shape);
|
|
||||||
if (cache != null) {
|
|
||||||
val pointer = cache.poll();
|
|
||||||
if (pointer != null) {
|
|
||||||
cacheZeroHit.incrementAndGet();
|
|
||||||
|
|
||||||
// since this memory chunk is going to be used now, remove it's amount from
|
|
||||||
zeroCachedAmount.addAndGet(-1 * reqMemory);
|
|
||||||
|
|
||||||
val pair = new PointersPair();
|
|
||||||
pair.setDevicePointer(new CudaPointer(pointer.address()));
|
|
||||||
pair.setHostPointer(new CudaPointer(pointer.address()));
|
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMemory);
|
|
||||||
MemoryTracker.getInstance().decrementCachedHostAmount(reqMemory);
|
|
||||||
|
|
||||||
return pair;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cacheZeroMiss.incrementAndGet();
|
|
||||||
|
|
||||||
if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation() && zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10
|
|
||||||
&& reqMemory < 16 * 1024 * 1024L) {
|
|
||||||
val preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls());
|
|
||||||
preallocator.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheZeroMiss.incrementAndGet();
|
|
||||||
return super.malloc(shape, point, location);
|
|
||||||
}
|
|
||||||
|
|
||||||
return super.malloc(shape, point, location);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
protected void ensureCacheHolder(AllocationShape shape) {
|
|
||||||
if (!zeroCache.containsKey(shape)) {
|
|
||||||
try {
|
|
||||||
singleLock.acquire();
|
|
||||||
if (!zeroCache.containsKey(shape)) {
|
|
||||||
zeroCache.put(shape, new CacheHolder(shape, zeroCachedAmount));
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
} finally {
|
|
||||||
singleLock.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method frees specific chunk of memory, described by AllocationPoint passed in.
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse.
|
|
||||||
*
|
|
||||||
* @param point
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void free(AllocationPoint point) {
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
|
||||||
super.free(point);
|
|
||||||
} else {
|
|
||||||
// if this point has no allocated chunk - step over it
|
|
||||||
if (point.getHostPointer() == null)
|
|
||||||
return;
|
|
||||||
|
|
||||||
AllocationShape shape = point.getShape();
|
|
||||||
long reqMemory = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
|
|
||||||
// we don't cache too big objects
|
|
||||||
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength() || zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) {
|
|
||||||
super.free(point);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ensureCacheHolder(shape);
|
|
||||||
|
|
||||||
/*
|
|
||||||
Now we should decide if this object can be cached or not
|
|
||||||
*/
|
|
||||||
CacheHolder cache = zeroCache.get(shape);
|
|
||||||
|
|
||||||
// memory chunks < threshold will be cached no matter what
|
|
||||||
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
|
|
||||||
Pointer.memset(point.getHostPointer(), 0, reqMemory);
|
|
||||||
cache.put(new CudaPointer(point.getHostPointer().address()));
|
|
||||||
} else {
|
|
||||||
long cacheEntries = cache.size();
|
|
||||||
long cacheHeight = zeroCache.size();
|
|
||||||
|
|
||||||
// total memory allocated within this bucket
|
|
||||||
long cacheDepth = cacheEntries * reqMemory;
|
|
||||||
|
|
||||||
Pointer.memset(point.getHostPointer(), 0, reqMemory);
|
|
||||||
cache.put(new CudaPointer(point.getHostPointer().address()));
|
|
||||||
}
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMemory);
|
|
||||||
MemoryTracker.getInstance().incrementCachedHostAmount(reqMemory);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private float getZeroCacheHitRatio() {
|
|
||||||
long totalHits = cacheZeroHit.get() + cacheZeroMiss.get();
|
|
||||||
float cacheRatio = cacheZeroHit.get() * 100 / (float) totalHits;
|
|
||||||
return cacheRatio;
|
|
||||||
}
|
|
||||||
|
|
||||||
private float getDeviceCacheHitRatio() {
|
|
||||||
long totalHits = cacheDeviceHit.get() + cacheDeviceMiss.get();
|
|
||||||
float cacheRatio = cacheDeviceHit.get() * 100 / (float) totalHits;
|
|
||||||
return cacheRatio;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
public void printCacheStats() {
|
|
||||||
log.debug("Cached host amount: " + zeroCachedAmount.get());
|
|
||||||
log.debug("Cached device amount: " + deviceCachedAmount.get(0).get());
|
|
||||||
log.debug("Total shapes in cache: " + zeroCache.size());
|
|
||||||
log.debug("Current host hit ratio: " + getZeroCacheHitRatio());
|
|
||||||
log.debug("Current device hit ratio: " + getDeviceCacheHitRatio());
|
|
||||||
}
|
|
||||||
|
|
||||||
protected class CacheHolder {
|
|
||||||
private Queue<Pointer> queue = new ConcurrentLinkedQueue<>();
|
|
||||||
private volatile int counter = 0;
|
|
||||||
private long reqMem = 0;
|
|
||||||
private final AtomicLong allocCounter;
|
|
||||||
|
|
||||||
public CacheHolder(AllocationShape shape, AtomicLong counter) {
|
|
||||||
this.reqMem = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
this.allocCounter = counter;
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized int size() {
|
|
||||||
return counter;
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized Pointer poll() {
|
|
||||||
val pointer = queue.poll();
|
|
||||||
if (pointer != null)
|
|
||||||
counter--;
|
|
||||||
|
|
||||||
return pointer;
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized void put(Pointer pointer) {
|
|
||||||
allocCounter.addAndGet(reqMem);
|
|
||||||
counter++;
|
|
||||||
queue.add(pointer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected class CachePreallocator extends Thread implements Runnable {
|
|
||||||
|
|
||||||
private AllocationShape shape;
|
|
||||||
private AllocationStatus location;
|
|
||||||
private int target;
|
|
||||||
|
|
||||||
public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) {
|
|
||||||
this.shape = shape;
|
|
||||||
this.target = numberOfEntries;
|
|
||||||
this.location = location;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
ensureCacheHolder(shape);
|
|
||||||
|
|
||||||
for (int i = 0; i < target; i++) {
|
|
||||||
val point = new AllocationPoint();
|
|
||||||
|
|
||||||
val pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location);
|
|
||||||
if (this.location == AllocationStatus.HOST) {
|
|
||||||
Pointer pointer = new CudaPointer(pair.getHostPointer().address());
|
|
||||||
CudaCachingZeroProvider.this.zeroCache.get(shape).put(pointer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void purgeCache() {
|
|
||||||
for (AllocationShape shape : zeroCache.keySet()) {
|
|
||||||
Pointer ptr = null;
|
|
||||||
while ((ptr = zeroCache.get(shape).poll()) != null) {
|
|
||||||
freeHost(ptr);
|
|
||||||
MemoryTracker.getInstance().decrementCachedHostAmount(shape.getNumberOfBytes());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
zeroCachedAmount.set(0);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,239 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.jita.memory.impl;
|
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import lombok.var;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationShape;
|
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
|
||||||
import org.nd4j.jita.allocator.pointers.PointersPair;
|
|
||||||
import org.nd4j.jita.allocator.utils.AllocationUtils;
|
|
||||||
import org.nd4j.jita.memory.MemoryProvider;
|
|
||||||
import org.nd4j.linalg.api.memory.AllocationsTracker;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.AllocationKind;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.nd4j.jita.allocator.impl.MemoryTracker;
|
|
||||||
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author raver119@gmail.com
|
|
||||||
*/
|
|
||||||
public class CudaDirectProvider implements MemoryProvider {
|
|
||||||
|
|
||||||
protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L;
|
|
||||||
private static Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);
|
|
||||||
protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
|
|
||||||
protected volatile ConcurrentHashMap<Long, Integer> validator = new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
|
|
||||||
private AtomicLong emergencyCounter = new AtomicLong(0);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method provides PointersPair to memory chunk specified by AllocationShape
|
|
||||||
*
|
|
||||||
* @param shape shape of desired memory chunk
|
|
||||||
* @param point target AllocationPoint structure
|
|
||||||
* @param location either HOST or DEVICE
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
|
|
||||||
|
|
||||||
//log.info("shape onCreate: {}, target: {}", shape, location);
|
|
||||||
|
|
||||||
switch (location) {
|
|
||||||
case HOST: {
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
|
|
||||||
// FIXME: this is WRONG, and directly leads to memleak
|
|
||||||
if (reqMem < 1)
|
|
||||||
reqMem = 1;
|
|
||||||
|
|
||||||
val pointer = nativeOps.mallocHost(reqMem, 0);
|
|
||||||
if (pointer == null)
|
|
||||||
throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: "
|
|
||||||
+ Thread.currentThread().getId());
|
|
||||||
|
|
||||||
// log.info("Host allocation, Thread id: {}, ReqMem: {}, Pointer: {}", Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null);
|
|
||||||
|
|
||||||
val hostPointer = new CudaPointer(pointer);
|
|
||||||
|
|
||||||
val devicePointerInfo = new PointersPair();
|
|
||||||
if (point.getPointers().getDevicePointer() == null) {
|
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
|
||||||
devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem));
|
|
||||||
} else
|
|
||||||
devicePointerInfo.setDevicePointer(point.getDevicePointer());
|
|
||||||
|
|
||||||
devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem));
|
|
||||||
|
|
||||||
point.setPointers(devicePointerInfo);
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem);
|
|
||||||
|
|
||||||
return devicePointerInfo;
|
|
||||||
}
|
|
||||||
case DEVICE: {
|
|
||||||
// cudaMalloc call
|
|
||||||
val deviceId = AtomicAllocator.getInstance().getDeviceId();
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
|
|
||||||
// FIXME: this is WRONG, and directly leads to memleak
|
|
||||||
if (reqMem < 1)
|
|
||||||
reqMem = 1;
|
|
||||||
|
|
||||||
AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem);
|
|
||||||
var pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
|
|
||||||
if (pointer == null) {
|
|
||||||
// try to purge stuff if we're low on memory
|
|
||||||
purgeCache(deviceId);
|
|
||||||
|
|
||||||
// call for gc
|
|
||||||
Nd4j.getMemoryManager().invokeGc();
|
|
||||||
|
|
||||||
pointer = nativeOps.mallocDevice(reqMem, deviceId, 0);
|
|
||||||
if (pointer == null)
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
val devicePointer = new CudaPointer(pointer);
|
|
||||||
|
|
||||||
var devicePointerInfo = point.getPointers();
|
|
||||||
if (devicePointerInfo == null)
|
|
||||||
devicePointerInfo = new PointersPair();
|
|
||||||
devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem));
|
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.DEVICE);
|
|
||||||
point.setDeviceId(deviceId);
|
|
||||||
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem);
|
|
||||||
return devicePointerInfo;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException("Unsupported location for malloc: [" + location + "]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method frees specific chunk of memory, described by AllocationPoint passed in
|
|
||||||
*
|
|
||||||
* @param point
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void free(AllocationPoint point) {
|
|
||||||
switch (point.getAllocationStatus()) {
|
|
||||||
case HOST: {
|
|
||||||
// cudaFreeHost call here
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
|
|
||||||
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
|
|
||||||
long result = nativeOps.freeHost(point.getPointers().getHostPointer());
|
|
||||||
if (result == 0) {
|
|
||||||
throw new RuntimeException("Can't deallocate [HOST] memory...");
|
|
||||||
}
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case DEVICE: {
|
|
||||||
if (point.isConstant())
|
|
||||||
return;
|
|
||||||
|
|
||||||
long reqMem = AllocationUtils.getRequiredMemory(point.getShape());
|
|
||||||
|
|
||||||
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem);
|
|
||||||
|
|
||||||
val pointers = point.getPointers();
|
|
||||||
|
|
||||||
long result = nativeOps.freeDevice(pointers.getDevicePointer(), 0);
|
|
||||||
if (result == 0)
|
|
||||||
throw new RuntimeException("Can't deallocate [DEVICE] memory...");
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method checks specified device for specified amount of memory
|
|
||||||
*
|
|
||||||
* @param deviceId
|
|
||||||
* @param requiredMemory
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
|
|
||||||
/*
|
|
||||||
long[] totalMem = new long[1];
|
|
||||||
long[] freeMem = new long[1];
|
|
||||||
|
|
||||||
|
|
||||||
JCuda.cudaMemGetInfo(freeMem, totalMem);
|
|
||||||
|
|
||||||
long free = freeMem[0];
|
|
||||||
long total = totalMem[0];
|
|
||||||
long used = total - free;
|
|
||||||
|
|
||||||
/*
|
|
||||||
We don't want to allocate memory if it's too close to the end of available ram.
|
|
||||||
*/
|
|
||||||
//if (configuration != null && used > total * configuration.getMaxDeviceMemoryUsed()) return false;
|
|
||||||
|
|
||||||
/*
|
|
||||||
if (free + requiredMemory < total * 0.85)
|
|
||||||
return true;
|
|
||||||
else return false;
|
|
||||||
*/
|
|
||||||
long freeMem = nativeOps.getDeviceFreeMemory(-1);
|
|
||||||
if (freeMem - requiredMemory < DEVICE_RESERVED_SPACE)
|
|
||||||
return false;
|
|
||||||
else
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void freeHost(Pointer pointer) {
|
|
||||||
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
nativeOps.freeHost(pointer);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void freeDevice(Pointer pointer, int deviceId) {
|
|
||||||
val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
|
||||||
nativeOps.freeDevice(pointer, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void purgeCache(int deviceId) {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void purgeCache() {
|
|
||||||
// no-op
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,220 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.jita.memory.impl;
|
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import org.bytedeco.javacpp.Pointer;
|
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
|
||||||
import org.nd4j.jita.allocator.impl.AllocationShape;
|
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
|
||||||
import org.nd4j.jita.allocator.impl.MemoryTracker;
|
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
|
||||||
import org.nd4j.jita.allocator.pointers.PointersPair;
|
|
||||||
import org.nd4j.jita.allocator.utils.AllocationUtils;
|
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This MemoryProvider implementation does caching for both host and device memory within predefined limits.
|
|
||||||
*
|
|
||||||
* @author raver119@gmail.com
|
|
||||||
*/
|
|
||||||
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
|
|
||||||
|
|
||||||
//protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation();
|
|
||||||
|
|
||||||
//protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache();
|
|
||||||
|
|
||||||
|
|
||||||
protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CacheHolder>> deviceCache =
|
|
||||||
new ConcurrentHashMap<>();
|
|
||||||
|
|
||||||
|
|
||||||
private static Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);
|
|
||||||
|
|
||||||
public CudaFullCachingProvider() {
|
|
||||||
|
|
||||||
init();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void init() {
|
|
||||||
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
|
|
||||||
|
|
||||||
deviceCachedAmount = new ArrayList<>();
|
|
||||||
|
|
||||||
for (int i = 0; i < numDevices; i++) {
|
|
||||||
deviceCachedAmount.add(new AtomicLong(0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method provides PointersPair to memory chunk specified by AllocationShape
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape.
|
|
||||||
*
|
|
||||||
* @param shape shape of desired memory chunk
|
|
||||||
* @param point target AllocationPoint structure
|
|
||||||
* @param location either HOST or DEVICE
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
|
|
||||||
val reqMemory = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
if (location == AllocationStatus.DEVICE && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) {
|
|
||||||
|
|
||||||
|
|
||||||
val deviceId = AtomicAllocator.getInstance().getDeviceId();
|
|
||||||
ensureDeviceCacheHolder(deviceId, shape);
|
|
||||||
|
|
||||||
val cache = deviceCache.get(deviceId).get(shape);
|
|
||||||
if (cache != null) {
|
|
||||||
val pointer = cache.poll();
|
|
||||||
if (pointer != null) {
|
|
||||||
cacheDeviceHit.incrementAndGet();
|
|
||||||
|
|
||||||
deviceCachedAmount.get(deviceId).addAndGet(-reqMemory);
|
|
||||||
|
|
||||||
val pair = new PointersPair();
|
|
||||||
pair.setDevicePointer(pointer);
|
|
||||||
|
|
||||||
point.setAllocationStatus(AllocationStatus.DEVICE);
|
|
||||||
point.setDeviceId(deviceId);
|
|
||||||
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory);
|
|
||||||
MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory);
|
|
||||||
|
|
||||||
return pair;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cacheDeviceMiss.incrementAndGet();
|
|
||||||
return super.malloc(shape, point, location);
|
|
||||||
}
|
|
||||||
return super.malloc(shape, point, location);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method frees specific chunk of memory, described by AllocationPoint passed in
|
|
||||||
*
|
|
||||||
* PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse.
|
|
||||||
*
|
|
||||||
* @param point
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void free(AllocationPoint point) {
|
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
|
||||||
if (point.isConstant())
|
|
||||||
return;
|
|
||||||
|
|
||||||
val shape = point.getShape();
|
|
||||||
val deviceId = point.getDeviceId();
|
|
||||||
val address = point.getDevicePointer().address();
|
|
||||||
val reqMemory = AllocationUtils.getRequiredMemory(shape);
|
|
||||||
// we don't cache too big objects
|
|
||||||
|
|
||||||
if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength() || deviceCachedAmount.get(deviceId).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) {
|
|
||||||
super.free(point);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ensureDeviceCacheHolder(deviceId, shape);
|
|
||||||
|
|
||||||
val cache = deviceCache.get(deviceId).get(shape);
|
|
||||||
|
|
||||||
if (point.getDeviceId() != deviceId)
|
|
||||||
throw new RuntimeException("deviceId changed!");
|
|
||||||
|
|
||||||
// memory chunks < threshold will be cached no matter what
|
|
||||||
if (reqMemory <= FORCED_CACHE_THRESHOLD) {
|
|
||||||
cache.put(new CudaPointer(point.getDevicePointer().address()));
|
|
||||||
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
|
|
||||||
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
|
|
||||||
cache.put(new CudaPointer(point.getDevicePointer().address()));
|
|
||||||
|
|
||||||
MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
|
|
||||||
MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
super.free(point);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method checks, if storage contains holder for specified shape
|
|
||||||
*
|
|
||||||
* @param deviceId
|
|
||||||
* @param shape
|
|
||||||
*/
|
|
||||||
protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
|
|
||||||
if (!deviceCache.containsKey(deviceId)) {
|
|
||||||
try {
|
|
||||||
synchronized (this) {
|
|
||||||
if (!deviceCache.containsKey(deviceId)) {
|
|
||||||
deviceCache.put(deviceId, new ConcurrentHashMap<AllocationShape, CacheHolder>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!deviceCache.get(deviceId).containsKey(shape)) {
|
|
||||||
try {
|
|
||||||
singleLock.acquire();
|
|
||||||
|
|
||||||
if (!deviceCache.get(deviceId).containsKey(shape)) {
|
|
||||||
deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount.get(deviceId)));
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
|
|
||||||
} finally {
|
|
||||||
singleLock.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected synchronized void purgeCache(int deviceId) {
|
|
||||||
for (AllocationShape shape : deviceCache.get(deviceId).keySet()) {
|
|
||||||
Pointer ptr = null;
|
|
||||||
while ((ptr = deviceCache.get(deviceId).get(shape).poll()) != null) {
|
|
||||||
freeDevice(ptr, deviceId);
|
|
||||||
MemoryTracker.getInstance().decrementCachedAmount(deviceId, shape.getNumberOfBytes());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCachedAmount.get(deviceId).set(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public synchronized void purgeCache() {
|
|
||||||
for (Integer device : deviceCache.keySet()) {
|
|
||||||
purgeCache(device);
|
|
||||||
}
|
|
||||||
super.purgeCache();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -17,34 +17,39 @@
|
||||||
package org.nd4j.linalg.jcublas;
|
package org.nd4j.linalg.jcublas;
|
||||||
|
|
||||||
|
|
||||||
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.graph.FlatArray;
|
||||||
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
import org.nd4j.jita.allocator.enums.AllocationStatus;
|
||||||
import org.nd4j.jita.allocator.enums.CudaConstants;
|
import org.nd4j.jita.allocator.enums.CudaConstants;
|
||||||
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.DataTypeEx;
|
|
||||||
import org.nd4j.linalg.api.buffer.FloatBuffer;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.BaseNDArray;
|
import org.nd4j.linalg.api.ndarray.BaseNDArray;
|
||||||
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
|
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
|
import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
|
||||||
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
|
||||||
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
||||||
|
import org.nd4j.linalg.api.ops.util.PrintVariable;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
|
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
|
||||||
|
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
import org.nd4j.linalg.memory.MemcpyDirection;
|
||||||
import org.nd4j.linalg.workspace.WorkspaceUtils;
|
import org.nd4j.linalg.workspace.WorkspaceUtils;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -387,10 +392,6 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
super(data, order);
|
super(data, order);
|
||||||
}
|
}
|
||||||
|
|
||||||
public JCublasNDArray(FloatBuffer floatBuffer, char order) {
|
|
||||||
super(floatBuffer, order);
|
|
||||||
}
|
|
||||||
|
|
||||||
public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) {
|
public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) {
|
||||||
super(buffer, shape, strides);
|
super(buffer, shape, strides);
|
||||||
}
|
}
|
||||||
|
@ -574,26 +575,16 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST;
|
MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST;
|
||||||
val prof = PerformanceTracker.getInstance().helperStartTransaction();
|
val prof = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
|
if (srcPoint.isActualOnDeviceSide()) {
|
||||||
// d2d copy
|
|
||||||
route = 1;
|
route = 1;
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
|
||||||
dstPoint.tickDeviceWrite();
|
dstPoint.tickDeviceWrite();
|
||||||
direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
||||||
} else if (dstPoint.getAllocationStatus() == AllocationStatus.HOST && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
|
} else {
|
||||||
route = 2;
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToHost, blocking ? context.getOldStream() : context.getSpecialStream());
|
|
||||||
dstPoint.tickHostWrite();
|
|
||||||
direction = MemcpyDirection.DEVICE_TO_HOST;
|
|
||||||
} else if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.HOST) {
|
|
||||||
route = 3;
|
route = 3;
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream());
|
||||||
dstPoint.tickDeviceWrite();
|
dstPoint.tickDeviceWrite();
|
||||||
direction = MemcpyDirection.HOST_TO_DEVICE;
|
direction = MemcpyDirection.HOST_TO_DEVICE;
|
||||||
} else {
|
|
||||||
route = 4;
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToHost, blocking ? context.getOldStream() : context.getSpecialStream());
|
|
||||||
dstPoint.tickHostWrite();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -650,30 +641,16 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
|
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(target);
|
Nd4j.getMemoryManager().setCurrentWorkspace(target);
|
||||||
|
|
||||||
// log.info("Leveraging...");
|
|
||||||
|
|
||||||
INDArray copy = null;
|
INDArray copy = null;
|
||||||
if (!this.isView()) {
|
if (!this.isView()) {
|
||||||
//if (1 < 0) {
|
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
DataBuffer buffer = Nd4j.createBuffer(this.length(), false);
|
val buffer = Nd4j.createBuffer(this.length(), false);
|
||||||
|
|
||||||
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
||||||
|
|
||||||
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
|
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
|
||||||
/*
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointDst.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
|
|
||||||
throw new ND4JIllegalStateException("memsetAsync 1 failed");
|
|
||||||
|
|
||||||
context.syncOldStream();
|
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointSrc.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0)
|
|
||||||
throw new ND4JIllegalStateException("memsetAsync 2 failed");
|
|
||||||
|
|
||||||
context.syncOldStream();
|
|
||||||
*/
|
|
||||||
|
|
||||||
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
||||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
@ -690,12 +667,11 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
|
|
||||||
context.syncOldStream();
|
context.syncOldStream();
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction);
|
||||||
|
|
||||||
copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
|
copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
|
||||||
|
|
||||||
// tag buffer as valid on device side
|
// tag buffer as valid on device side
|
||||||
pointDst.tickHostRead();
|
|
||||||
pointDst.tickDeviceWrite();
|
pointDst.tickDeviceWrite();
|
||||||
|
|
||||||
AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
|
AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
|
||||||
|
@ -728,6 +704,7 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
||||||
|
|
||||||
|
|
||||||
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
|
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
|
||||||
|
|
||||||
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
|
||||||
|
@ -764,6 +741,38 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) {
|
||||||
|
Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only");
|
||||||
|
try {
|
||||||
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
|
DataOutputStream dos = new DataOutputStream(bos);
|
||||||
|
|
||||||
|
val numWords = this.length();
|
||||||
|
val ub = (CudaUtf8Buffer) buffer;
|
||||||
|
// writing length first
|
||||||
|
val t = length();
|
||||||
|
val ptr = (BytePointer) ub.pointer();
|
||||||
|
|
||||||
|
// now write all strings as bytes
|
||||||
|
for (int i = 0; i < ub.length(); i++) {
|
||||||
|
dos.writeByte(ptr.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
val bytes = bos.toByteArray();
|
||||||
|
return FlatArray.createBufferVector(builder, bytes);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getString(long index) {
|
||||||
|
if (!isS())
|
||||||
|
throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]");
|
||||||
|
|
||||||
|
return ((CudaUtf8Buffer) data).getString(index);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@Override
|
@Override
|
||||||
public INDArray convertToHalfs() {
|
public INDArray convertToHalfs() {
|
||||||
|
|
|
@ -18,11 +18,9 @@ package org.nd4j.linalg.jcublas;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import lombok.var;
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.DataTypeEx;
|
import org.nd4j.linalg.api.buffer.DataTypeEx;
|
||||||
import org.nd4j.linalg.api.buffer.Utf8Buffer;
|
|
||||||
import org.nd4j.linalg.api.memory.enums.MemoryKind;
|
import org.nd4j.linalg.api.memory.enums.MemoryKind;
|
||||||
import org.nd4j.linalg.api.ops.custom.Flatten;
|
import org.nd4j.linalg.api.ops.custom.Flatten;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
||||||
|
@ -34,12 +32,10 @@ import org.nd4j.linalg.jcublas.buffer.*;
|
||||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
import org.nd4j.linalg.memory.MemcpyDirection;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.bytedeco.javacpp.*;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.indexer.*;
|
|
||||||
import org.nd4j.jita.allocator.enums.CudaConstants;
|
import org.nd4j.jita.allocator.enums.CudaConstants;
|
||||||
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
import org.nd4j.jita.allocator.impl.AllocationPoint;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.jita.allocator.utils.AllocationUtils;
|
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -51,19 +47,12 @@ import org.nd4j.linalg.compression.CompressedDataBuffer;
|
||||||
import org.nd4j.linalg.compression.CompressionDescriptor;
|
import org.nd4j.linalg.compression.CompressionDescriptor;
|
||||||
import org.nd4j.linalg.compression.CompressionType;
|
import org.nd4j.linalg.compression.CompressionType;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.BaseNDArrayFactory;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.blas.*;
|
import org.nd4j.linalg.jcublas.blas.*;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.nd4j.nativeblas.*;
|
import org.nd4j.nativeblas.*;
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.nio.ByteOrder;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -216,7 +205,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
@Override
|
@Override
|
||||||
public INDArray create(Collection<String> strings, long[] shape, char order) {
|
public INDArray create(Collection<String> strings, long[] shape, char order) {
|
||||||
val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8);
|
val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8);
|
||||||
val buffer = new Utf8Buffer(strings);
|
val buffer = new CudaUtf8Buffer(strings);
|
||||||
val list = new ArrayList<String>(strings);
|
val list = new ArrayList<String>(strings);
|
||||||
return Nd4j.createArrayFromShapeBuffer(buffer, pairShape);
|
return Nd4j.createArrayFromShapeBuffer(buffer, pairShape);
|
||||||
}
|
}
|
||||||
|
@ -360,8 +349,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray concat(int dimension, INDArray... toConcat) {
|
public INDArray concat(int dimension, INDArray... toConcat) {
|
||||||
if (Nd4j.getExecutioner() instanceof GridExecutioner)
|
Nd4j.getExecutioner().push();
|
||||||
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
|
|
||||||
|
|
||||||
return Nd4j.exec(new Concat(dimension, toConcat))[0];
|
return Nd4j.exec(new Concat(dimension, toConcat))[0];
|
||||||
}
|
}
|
||||||
|
@ -517,9 +505,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
||||||
CudaContext context = allocator.getFlowController().prepareAction(ret, source);
|
CudaContext context = allocator.getFlowController().prepareAction(ret, source);
|
||||||
|
|
||||||
Pointer x = AtomicAllocator.getInstance().getPointer(source, context);
|
val x = ((BaseCudaDataBuffer) source.data()).getOpaqueDataBuffer();
|
||||||
|
val z = ((BaseCudaDataBuffer) ret.data()).getOpaqueDataBuffer();
|
||||||
Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context);
|
Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context);
|
||||||
Pointer z = AtomicAllocator.getInstance().getPointer(ret, context);
|
|
||||||
Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
|
Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
|
||||||
|
|
||||||
PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
|
PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
|
||||||
|
@ -545,14 +533,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
|
|
||||||
nativeOps.pullRows(extras,
|
nativeOps.pullRows(extras,
|
||||||
null,
|
x, (LongPointer) source.shapeInfoDataBuffer().addressPointer(), (LongPointer) xShape,
|
||||||
(LongPointer) source.shapeInfoDataBuffer().addressPointer(),
|
z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) zShape,
|
||||||
x,
|
|
||||||
(LongPointer) xShape,
|
|
||||||
null,
|
|
||||||
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
|
|
||||||
z,
|
|
||||||
(LongPointer) zShape,
|
|
||||||
indexes.length,
|
indexes.length,
|
||||||
(LongPointer) pIndex,
|
(LongPointer) pIndex,
|
||||||
(LongPointer) tadShapeInfo,
|
(LongPointer) tadShapeInfo,
|
||||||
|
@ -601,7 +583,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
||||||
xPointers[i] = point.getPointers().getDevicePointer().address();
|
xPointers[i] = point.getDevicePointer().address();
|
||||||
point.tickDeviceWrite();
|
point.tickDeviceWrite();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -710,7 +692,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
||||||
xPointers[i] = point.getPointers().getDevicePointer().address();
|
xPointers[i] = point.getDevicePointer().address();
|
||||||
point.tickDeviceWrite();
|
point.tickDeviceWrite();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1324,11 +1306,11 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
PointerPointer extraz = new PointerPointer(null, // not used
|
PointerPointer extraz = new PointerPointer(null, // not used
|
||||||
context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());
|
context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());
|
||||||
|
|
||||||
|
val x = ((BaseCudaDataBuffer) tensor.data()).getOpaqueDataBuffer();
|
||||||
|
|
||||||
|
|
||||||
nativeOps.tear(extraz,
|
nativeOps.tear(extraz,
|
||||||
null,
|
x, (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
|
||||||
(LongPointer) tensor.shapeInfoDataBuffer().addressPointer(),
|
|
||||||
AtomicAllocator.getInstance().getPointer(tensor, context),
|
|
||||||
(LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context),
|
|
||||||
new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)),
|
new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)),
|
||||||
(LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context),
|
(LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context),
|
||||||
(LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
|
(LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context),
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -46,6 +46,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, specialPointer, indexer, length);
|
super(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaBfloat16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
|
||||||
|
super(buffer, dataType, length, offset);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
@ -128,18 +132,6 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset);
|
super(data, copy, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CudaBfloat16DataBuffer(byte[] data, long length) {
|
|
||||||
super(data, length, DataType.BFLOAT16);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length) {
|
|
||||||
super(buffer, (int) length, DataType.BFLOAT16);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaBfloat16DataBuffer(ByteBuffer buffer, long length, long offset) {
|
|
||||||
super(buffer, length, offset, DataType.BFLOAT16);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void assign(long[] indices, double[] data, boolean contiguous, long inc) {
|
public void assign(long[] indices, double[] data, boolean contiguous, long inc) {
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,10 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, specialPointer, indexer, length);
|
super(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaBoolDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
|
||||||
|
super(buffer, dataType, length, offset);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
@ -132,18 +136,6 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset);
|
super(data, copy, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CudaBoolDataBuffer(byte[] data, long length) {
|
|
||||||
super(data, length, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaBoolDataBuffer(ByteBuffer buffer, long length) {
|
|
||||||
super(buffer, (int) length, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaBoolDataBuffer(ByteBuffer buffer, long length, long offset) {
|
|
||||||
super(buffer, length, offset, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataBuffer create(long length) {
|
protected DataBuffer create(long length) {
|
||||||
return new CudaBoolDataBuffer(length);
|
return new CudaBoolDataBuffer(length);
|
||||||
|
|
|
@ -49,6 +49,10 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, specialPointer, indexer, length);
|
super(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
|
||||||
|
super(buffer, dataType, length, offset);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
@ -131,18 +135,6 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset);
|
super(data, copy, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CudaByteDataBuffer(byte[] data, long length) {
|
|
||||||
super(data, length, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaByteDataBuffer(ByteBuffer buffer, long length) {
|
|
||||||
super(buffer, (int) length, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaByteDataBuffer(ByteBuffer buffer, long length, long offset) {
|
|
||||||
super(buffer, length, offset, DataType.HALF);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataBuffer create(long length) {
|
protected DataBuffer create(long length) {
|
||||||
return new CudaByteDataBuffer(length);
|
return new CudaByteDataBuffer(length);
|
||||||
|
|
|
@ -49,6 +49,10 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, specialPointer, indexer, length);
|
super(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaDoubleDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
|
||||||
|
super(buffer, dataType, length, offset);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
@ -138,18 +142,6 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset);
|
super(data, copy, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CudaDoubleDataBuffer(byte[] data, long length) {
|
|
||||||
super(data, length, DataType.DOUBLE);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaDoubleDataBuffer(ByteBuffer buffer, long length) {
|
|
||||||
super(buffer, (int) length, DataType.DOUBLE);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaDoubleDataBuffer(ByteBuffer buffer, long length, long offset) {
|
|
||||||
super(buffer, length, offset, DataType.DOUBLE);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataBuffer create(long length) {
|
protected DataBuffer create(long length) {
|
||||||
return new CudaDoubleDataBuffer(length);
|
return new CudaDoubleDataBuffer(length);
|
||||||
|
@ -210,14 +202,7 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer {
|
||||||
this.length = n;
|
this.length = n;
|
||||||
this.elementSize = 8;
|
this.elementSize = 8;
|
||||||
|
|
||||||
//wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize());
|
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.DOUBLE), false);
|
||||||
//wrappedBuffer.order(ByteOrder.nativeOrder());
|
|
||||||
|
|
||||||
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
|
|
||||||
new AllocationShape(length, elementSize, DataType.DOUBLE), false);
|
|
||||||
this.trackingPoint = allocationPoint.getObjectId();
|
|
||||||
//this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer();
|
|
||||||
//this.wrappedBuffer.order(ByteOrder.nativeOrder());
|
|
||||||
|
|
||||||
setData(arr);
|
setData(arr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,6 +50,10 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, specialPointer, indexer, length);
|
super(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaFloatDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) {
|
||||||
|
super(buffer, dataType, length, offset);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
@ -133,19 +137,6 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset);
|
super(data, copy, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CudaFloatDataBuffer(byte[] data, long length) {
|
|
||||||
super(data, length, DataType.FLOAT);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaFloatDataBuffer(ByteBuffer buffer, long length) {
|
|
||||||
super(buffer, (int) length, DataType.FLOAT);
|
|
||||||
}
|
|
||||||
|
|
||||||
public CudaFloatDataBuffer(ByteBuffer buffer, long length, long offset) {
|
|
||||||
super(buffer, length, offset, DataType.FLOAT);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataBuffer create(long length) {
|
protected DataBuffer create(long length) {
|
||||||
return new CudaFloatDataBuffer(length);
|
return new CudaFloatDataBuffer(length);
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue