diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index 17275e574..428a4c7a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -121,6 +121,7 @@ public class PReLULayer extends BaseLayer { public static class Builder extends FeedForwardLayer.Builder { public Builder(){ + //Default to 0s, and don't inherit global default this.weightInitFn = new WeightInitConstant(0); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java index 3c311eed3..7f1b682ab 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java @@ -20,7 +20,7 @@ import lombok.Getter; import lombok.NonNull; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.buffer.FloatBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -63,7 +63,7 @@ public class NegativeHolder implements Serializable { protected void makeTable(int tableSize, double power) { int vocabSize = vocab.numWords(); - table = Nd4j.create(new FloatBuffer(tableSize)); + table = Nd4j.create(DataType.FLOAT, tableSize); double trainWordsPow = 0.0; for (String word : vocab.words()) { trainWordsPow += Math.pow(vocab.wordFrequency(word), power); diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index f89ee6e1d..a426e471e 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -42,6 +42,8 @@ #include #include #include +#include +#include namespace nd4j { @@ -301,14 +303,11 @@ namespace nd4j { * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - static void registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList); - static void prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - static void registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList); - static void preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); + static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); + static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); + static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); /** * This method returns buffer pointer offset by given number of elements, wrt own data type diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index a08db4f16..b79f52fb3 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -223,6 +223,8 @@ NDArray::NDArray(std::shared_ptr buffer, const ShapeDescriptor& desc setShapeInfo(descriptor); _buffer = buffer; + + _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } //////////////////////////////////////////////////////////////////////// @@ -288,6 +290,8 @@ NDArray::NDArray(std::shared_ptr buffer, const char order, const std setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); _buffer = buffer; + + _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } //////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index a85473c7f..10893c08d 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -68,6 +68,7 @@ bool verbose = false; #include #include #include +#include #include #include #include @@ -76,6 +77,8 @@ bool verbose = false; #include #include +typedef nd4j::InteropDataBuffer OpaqueDataBuffer; + extern "C" { /** @@ -118,11 +121,9 @@ ND4J_EXPORT void setTADThreshold(int num); */ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -137,13 +138,10 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -160,28 +158,20 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, ND4J_EXPORT void execBroadcast( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execBroadcastBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -198,23 +188,17 @@ ND4J_EXPORT void execBroadcastBool( ND4J_EXPORT void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execPairwiseTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); /** @@ -228,36 +212,28 @@ ND4J_EXPORT void execPairwiseTransformBool( */ ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -270,46 +246,34 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape); + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); /** * @@ -324,13 +288,10 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * @@ -343,13 +304,10 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); /** * * @param opNum @@ -365,30 +323,22 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets); @@ -405,22 +355,16 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hSscalarShapeInfo, - void *dScalar, Nd4jLong *dSscalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, void *extraParams); ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hSscalarShapeInfo, - void *dScalar, Nd4jLong *dSscalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, void *extraParams); /** @@ -432,11 +376,9 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected); /** * @@ -449,11 +391,9 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected); /** * @@ -468,13 +408,10 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); @@ -490,42 +427,32 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); /** @@ -543,29 +470,21 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, */ ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); @@ -904,10 +823,8 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); * @param zTadOffsets */ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dx, Nd4jLong *dxShapeInfo, - void *z, Nd4jLong *zShapeInfo, - void *dz, Nd4jLong *dzShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1086,8 +1003,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); /** @@ -1106,12 +1022,9 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeBuffer, Nd4jLong *dYShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); /** @@ -1128,10 +1041,8 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, void *extraArguments); @@ -1174,52 +1085,6 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, */ ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - } /** @@ -1561,11 +1426,10 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); * @return */ ND4J_EXPORT void tear(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dx, Nd4jLong *dxShapeInfo, - Nd4jPointer *targets, Nd4jLong *zShapeInfo, - Nd4jLong *tadShapeInfo, - Nd4jLong *tadOffsets); + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, + Nd4jPointer *targets, Nd4jLong *zShapeInfo, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffsets); ND4J_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); ND4J_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); @@ -1739,6 +1603,8 @@ ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace) ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); @@ -1766,6 +1632,28 @@ ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); +ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); +ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); +ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); +ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); +ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); +ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); +ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); +ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); + ND4J_EXPORT int binaryLevel(); ND4J_EXPORT int optimalLevel(); diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index 9dd2ed967..9bdf41a16 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -184,16 +184,16 @@ void NDArray::synchronize(const char* msg) const { // no-op } -void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { // no-op } -void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { // no-op } diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 956802961..e150a2039 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -102,13 +102,11 @@ void setTADThreshold(int num) { */ void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -127,15 +125,12 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimensionLength */ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -144,17 +139,17 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, auto hTADShapeInfo = tadPack.primaryShapeInfo(); auto hTADOffsets = tadPack.primaryOffsets(); - auto hz = reinterpret_cast(hZ); + auto hz = reinterpret_cast(dbZ->primary()); NativeOpExecutioner::execIndexReduce(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, hz, hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -181,16 +176,12 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, */ void execBroadcast(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -205,16 +196,16 @@ void execBroadcast(Nd4jPointer *extraPointers, NativeOpExecutioner::execBroadcast(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, + dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); } catch (std::exception &e) { @@ -225,17 +216,13 @@ void execBroadcast(Nd4jPointer *extraPointers, void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -250,16 +237,16 @@ void execBroadcastBool(Nd4jPointer *extraPointers, NativeOpExecutioner::execBroadcastBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, + dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, extraParams, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, @@ -285,27 +272,24 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execPairwiseTransform(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams); } catch (std::exception &e) { @@ -317,28 +301,25 @@ void execPairwiseTransform( void execPairwiseTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execPairwiseBoolTransform(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hY, + dbY->primary(), hYShapeInfo, - dY, + dbY->special(), dYShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams); } catch (std::exception &e) { @@ -359,23 +340,21 @@ void execPairwiseTransformBool( void execReduceFloat( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceFloatScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -386,23 +365,21 @@ void execReduceFloat( void execReduceSame( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceSameScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -413,22 +390,20 @@ void execReduceSame( void execReduceBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceBoolScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -439,22 +414,20 @@ void execReduceBool( void execReduceLong( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { NativeOpExecutioner::execReduceLongScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -473,15 +446,12 @@ void execReduceLong( */ void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -491,14 +461,14 @@ void execReduceFloat2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPackX.primaryOffsets(); NativeOpExecutioner::execReduceFloat(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -512,15 +482,12 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -530,14 +497,14 @@ void execReduceBool2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -551,15 +518,12 @@ void execReduceBool2(Nd4jPointer *extraPointers, void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -569,14 +533,14 @@ void execReduceSame2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceSame(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -590,15 +554,12 @@ void execReduceSame2(Nd4jPointer *extraPointers, void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, @@ -608,14 +569,14 @@ void execReduceLong2(Nd4jPointer *extraPointers, auto hTADOffsets = tadPack.primaryOffsets(); NativeOpExecutioner::execReduceLong(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -640,16 +601,13 @@ void execReduceLong2(Nd4jPointer *extraPointers, */ void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execReduce3(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, - dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -666,16 +624,13 @@ void execReduce3(Nd4jPointer *extraPointers, * @param hYShapeInfo */ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { - NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), + hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -696,24 +651,20 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, */ void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); if (extraPointers == nullptr || extraPointers[2] == 0) { - NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, dXShapeInfo, - extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, + NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } else { @@ -724,9 +675,9 @@ void execReduce3Tad(Nd4jPointer *extraPointers, auto hTADShapeInfo = tadPack.primaryShapeInfo(); auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, - dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, - hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, + NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), + hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, nullptr, nullptr); } } catch (std::exception &e) { @@ -753,27 +704,24 @@ bool isBlasVersionMatches(int major, int minor, int build) { void execScalar( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { NativeOpExecutioner::execScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalar, + dbScalar->primary(), hScalarShapeInfo, - dScalar, + dbScalar->special(), dScalarShapeInfo, extraParams); } catch (std::exception &e) { @@ -785,27 +733,24 @@ void execScalar( void execScalarBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { NativeOpExecutioner::execScalarBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalar, + dbScalar->primary(), hScalarShapeInfo, - dScalar, + dbScalar->special(), dScalarShapeInfo, extraParams); } catch (std::exception &e) { @@ -823,23 +768,21 @@ void execScalarBool( */ void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { NativeOpExecutioner::execSummaryStatsScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, biasCorrected); } catch (std::exception &e) { @@ -858,23 +801,21 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, */ void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { NativeOpExecutioner::execSummaryStats(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, biasCorrected); } catch (std::exception &e) { @@ -895,30 +836,27 @@ void execSummaryStats(Nd4jPointer *extraPointers, */ void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execSummaryStats(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, @@ -944,21 +882,19 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, void execTransformFloat( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformFloat(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dZ, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -972,21 +908,19 @@ void execTransformFloat( void execTransformSame( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformSame(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1000,21 +934,19 @@ void execTransformSame( void execTransformBool( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1028,21 +960,19 @@ void execTransformBool( void execTransformAny( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformAny(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1056,21 +986,19 @@ void execTransformAny( void execTransformStrict( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { NativeOpExecutioner::execTransformStrict(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, extraParams, nullptr, @@ -1083,27 +1011,23 @@ void execTransformStrict( void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); - NativeOpExecutioner::execReduce3All(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(), + hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); @@ -1398,10 +1322,8 @@ void pullRowsGeneric(void *vx, } void pullRows(Nd4jPointer *extraPointers, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1411,7 +1333,7 @@ void pullRows(Nd4jPointer *extraPointers, try { auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (hX, hXShapeInfo, hZ, hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1459,8 +1381,7 @@ void tearGeneric(void *vx, } void tear(Nd4jPointer *extraPointers, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, Nd4jPointer *targets, Nd4jLong *hZShapeInfo, Nd4jLong *tadShapeInfo, @@ -1468,7 +1389,7 @@ void tear(Nd4jPointer *extraPointers, try { auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearGeneric, (hX, hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, tearGeneric, (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1653,35 +1574,31 @@ int getDevice() { void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execScalar(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalars, + dbScalars->primary(), hScalarShapeInfo, - dScalars, + dbScalars->special(), dScalarShapeInfo, dimension, shape::length(hDimensionShape), @@ -1697,35 +1614,31 @@ void execScalarTad(Nd4jPointer *extraPointers, void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(hDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); NativeOpExecutioner::execScalarBool(nullptr, opNum, - hX, + dbX->primary(), hXShapeInfo, - dX, + dbX->special(), dXShapeInfo, extraParams, - hZ, + dbZ->primary(), hZShapeInfo, - dZ, + dbZ->special(), dZShapeInfo, - hScalars, + dbScalars->primary(), hScalarShapeInfo, - dScalars, + dbScalars->special(), dScalarShapeInfo, dimension, dimensionLength, @@ -1809,11 +1722,10 @@ void execAggregateBatch(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1823,15 +1735,12 @@ void execRandom(Nd4jPointer *extraPointers, void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1841,13 +1750,11 @@ void execRandom3(Nd4jPointer *extraPointers, void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2863,6 +2770,15 @@ void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffe void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + +void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +} + +void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +} + void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { ptr->setTArguments(arguments, numberOfArguments); } @@ -3176,6 +3092,91 @@ bool isOptimalRequirementsMet() { #endif } +OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { + auto dtype = DataTypeUtils::fromInt(dataType); + return new nd4j::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype) , dtype, allocateBoth); +} + +Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->primary(); +} + +Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->special(); +} + +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { + delete dataBuffer; +} + +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); +} + +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); +} + +void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocatePrimary(); +} + +void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocateSpecial(); +} + +void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); +} + +OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); +} + +void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToSpecial(); +} + +void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToPrimary(nullptr); +} + +void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readPrimary(); +} + +void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writePrimary(); +} + +void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readSpecial(); +} + +void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writeSpecial(); +} + +void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->expand(elements); +} + +int dbLocality(OpaqueDataBuffer *dataBuffer) { + return 0; +} + +void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { + dataBuffer->setDeviceId(deviceId); +} + +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->deviceId(); +} + +void dbClose(OpaqueDataBuffer *dataBuffer) { + dataBuffer->getDataBuffer()->close(); +} + BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index 48c7a7933..81c8070b3 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -236,7 +236,7 @@ void NDArray::synchronize(const char* msg) const { } //////////////////////////////////////////////////////////////////////// -void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -252,7 +252,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list& wri } //////////////////////////////////////////////////////////////////////// -void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { for (const auto& p : readList) if(p != nullptr) @@ -264,7 +264,7 @@ void NDArray::registerSpecialUse(const std::initializer_list& wr } //////////////////////////////////////////////////////////////////////// -void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -280,7 +280,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list& wri } //////////////////////////////////////////////////////////////////////// -void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { +void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { for (const auto& p : readList) if(p != nullptr) diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 3d86d94f8..16c888c0a 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -229,17 +229,19 @@ public: void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -249,17 +251,21 @@ void execPairwiseTransform( Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execPairwiseTransformBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, - dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -269,16 +275,21 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -288,22 +299,16 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - //Nd4jLong *tadOnlyShapeInfo = reinterpret_cast(extraPointers[0]); - //Nd4jLong *tadOffsets = reinterpret_cast(extraPointers[1]); - //Nd4jLong *tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[2]); - //Nd4jLong *tadOffsetsZ = reinterpret_cast(extraPointers[3]); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - auto dimension = reinterpret_cast(dDimension); + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); @@ -313,10 +318,15 @@ void execBroadcastBool(Nd4jPointer *extraPointers, auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension, - dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execBroadcastBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -338,16 +348,15 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void execBroadcast( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); cudaStream_t *stream = reinterpret_cast(extraPointers[1]); @@ -362,13 +371,15 @@ void execBroadcast( auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3 opNum:[%i]\n", opNum); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcast(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + NativeOpExecutioner::execBroadcast(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -388,15 +399,19 @@ void execBroadcast( //////////////////////////////////////////////////////////////////////// void execReduceFloat(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -406,15 +421,19 @@ void execReduceFloat(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceSame(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSameScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduceSameScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -424,25 +443,30 @@ void execReduceSame(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceSame2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceSame(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -452,25 +476,30 @@ void execReduceSame2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceLong2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceLong(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceLong(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -480,19 +509,16 @@ void execReduceLong2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceLong(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("LF7 opNum:[%i]\n", opNum); - auto reductionPointer = reinterpret_cast(extraPointers[4]); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); @@ -507,11 +533,15 @@ void execReduceLong(Nd4jPointer *extraPointers, dim3 launchDims(numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, - ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, - dZ, dZShapeInfo, hXShapeInfo, nullptr, 0, reductionPointer, - dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); + ::execReduceScalar(launchDims, stream, opNum, + dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, + extraParams, + dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hXShapeInfo, + nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -521,25 +551,30 @@ void execReduceLong(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceBool2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -549,19 +584,16 @@ void execReduceBool2(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("BF7 opNum:[%i]\n", opNum); - auto reductionPointer = reinterpret_cast(extraPointers[4]); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); @@ -576,11 +608,15 @@ void execReduceBool(Nd4jPointer *extraPointers, dim3 launchDims(numBlocks, blockWidth, 32768); BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, - ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, - dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, - dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); + ::execReduceScalar(launchDims, stream, opNum, + dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, + extraParams, + dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hZShapeInfo, + nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -601,25 +637,30 @@ void execReduceBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execIndexReduce(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduce(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execIndexReduce(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + (int *) dbDimension->special(), dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -638,25 +679,30 @@ void execIndexReduce(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), - tadPack.specialOffsets()); + NativeOpExecutioner::execReduceFloat(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -674,15 +720,19 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void execIndexReduceScalar( Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo){ + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo){ try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -691,18 +741,23 @@ void execIndexReduceScalar( //////////////////////////////////////////////////////////////////////// void execTransformSame(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformSame(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -711,18 +766,23 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformBool(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -731,19 +791,24 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformAny(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto stream = reinterpret_cast(extraPointers[1]); auto streamSpecial = reinterpret_cast(extraPointers[4]); LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast(extraPointers[6])); - NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, nullptr, nullptr); + NativeOpExecutioner::execTransformAny(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + nullptr, nullptr); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -752,18 +817,23 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformStrict(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraParams) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformStrict(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformStrict(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -772,18 +842,23 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execTransformFloat(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformFloat(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraParams, + tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1094,7 +1169,43 @@ Nd4jLong getDeviceTotalMemory(int device) { } int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return memcpyAsync(dst, src, size, flags, reserved); + cudaMemcpyKind kind; + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } + break; + case 1: { + kind = cudaMemcpyHostToDevice; + } + break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } + break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } + break; + default: { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpy failed"); + return 0; + } + + return 1; } int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { @@ -1131,11 +1242,12 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j auto dZ = cudaMemcpyAsync(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind, *pStream); //auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); if (dZ != 0) { - printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); fflush(stdout); fflush(stderr); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyAsync failed"); + return 0; } return 1; @@ -1348,10 +1460,8 @@ Nd4jPointer getConstantSpace() { } void pullRows(Nd4jPointer *extraPointers, - void *x, Nd4jLong *xShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *z, Nd4jLong *zShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, Nd4jLong *tadShapeInfo, @@ -1359,14 +1469,18 @@ void pullRows(Nd4jPointer *extraPointers, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); dim3 launchDims(64, 256, 1024); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, - (launchDims, stream, dX, dZ, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); DEBUG_KERNEL(stream, -1); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1500,16 +1614,21 @@ void setTADThreshold(int num) { //////////////////////////////////////////////////////////////////////// void execSummaryStats(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, bool biasCorrected) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + NativeOpExecutioner::execSummaryStats(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1519,22 +1638,29 @@ void execSummaryStats(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, - hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, - tadOffsets, biasCorrected); + NativeOpExecutioner::execSummaryStats(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + tadShapeInfo, tadOffsets, + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1544,17 +1670,21 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1564,22 +1694,22 @@ void execReduce3(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - reinterpret_cast(hDimension), + dimension, shape::length(hDimensionShape)); auto tadLength = shape::length(tadPack.primaryShapeInfo()); auto yLength = shape::length(hYShapeInfo); @@ -1589,16 +1719,23 @@ void execReduce3Tad(Nd4jPointer *extraPointers, if (tadLength == yLength || tadLength == xLength) { // nd4j_printf("== way\n",""); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, - dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + NativeOpExecutioner::execReduce3(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } else - NativeOpExecutioner::execReduce3TAD(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, - dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, - yTadOnlyShapeInfo, yTadOffsets); + NativeOpExecutioner::execReduce3TAD(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dimension, dimensionLength, + tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets); + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1607,17 +1744,21 @@ void execReduce3Tad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *extraParams, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo) { + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3Scalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + NativeOpExecutioner::execReduce3Scalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1627,18 +1768,21 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, //////////////////////////////////////////////////////////////////////// void execScalarBool(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, - extraParams); + NativeOpExecutioner::execScalarBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1648,25 +1792,30 @@ void execScalarBool(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, - dZ, dZShapeInfo, hScalars, hScalarShapeInfo, dScalars, dScalarShapeInfo, - dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execScalarBool(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParams, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + dimension, dimensionLength, + tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1676,17 +1825,21 @@ void execScalarBoolTad(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalar(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, - hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, extraParams); + NativeOpExecutioner::execScalar(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1696,19 +1849,18 @@ void execScalar(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execScalarTad(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalars, Nd4jLong *hScalarShapeInfo, - void *dScalars, Nd4jLong *dScalarShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalars, Nd4jLong *hScalarShapeInfo, Nd4jLong *dScalarShapeInfo, void *extraParams, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); cudaStream_t *stream = reinterpret_cast(extraPointers[1]); @@ -1725,10 +1877,12 @@ void execScalarTad(Nd4jPointer *extraPointers, #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); #endif DEBUG_KERNEL(stream, opNum); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1777,12 +1931,17 @@ void execAggregateBatch(Nd4jPointer *extraPointers, void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1791,15 +1950,19 @@ void execRandom(Nd4jPointer *extraPointers, //////////////////////////////////////////////////////////////////////// void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, - dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1808,17 +1971,21 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, //////////////////////////////////////////////////////////////////////// void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraArguments) { try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -1924,21 +2091,24 @@ Nd4jPointer pointerForAddress(Nd4jLong address) { } void tear(Nd4jPointer *extras, - void *x, Nd4jLong *xShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dXShapeInfo, Nd4jPointer *targets, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { try { + InteropDataBuffer::prepareSpecialUse({}, {dbX}); + cudaStream_t *stream = reinterpret_cast(extras[1]); dim3 launchDims(512, 512, 512); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, - (launchDims, stream, dX, dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), + (launchDims, stream, dbX->special(), dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); nd4j::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + + InteropDataBuffer::registerSpecialUse({}, {dbX}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2100,25 +2270,30 @@ void decodeThreshold(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, //////////////////////////////////////////////////////////////////////// void execReduce3All(Nd4jPointer *extraPointers, int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape, + OpaqueDataBuffer *dbY, Nd4jLong *hYShapeInfo, Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { try { - auto dimension = reinterpret_cast(dDimension); + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); int dimensionLength = static_cast(shape::length(hDimensionShape)); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3All(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, - hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, - dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + NativeOpExecutioner::execReduce3All(&lc, opNum, + dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), + extraParamsVals, + dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -3384,27 +3559,43 @@ nd4j::graph::Context* createGraphContext(int nodeId) { nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { return &ptr->randomGenerator(); } + void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) { ptr->markInplace(reallyInplace); } + void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { ptr->setCudaContext(stream, reductionPointer, allocationPointer); } + void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } + +void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +} + +void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +} + void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { ptr->setTArguments(arguments, numberOfArguments); } + void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { ptr->setIArguments(arguments, numberOfArguments); } + void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { ptr->setBArguments(arguments, numberOfArguments); } + void deleteGraphContext(nd4j::graph::Context* ptr) { delete ptr; } @@ -3581,4 +3772,97 @@ bool isOptimalRequirementsMet() { void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { ptr->allowHelpers(reallyAllow); +} + +OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { + auto dtype = DataTypeUtils::fromInt(dataType); + return new nd4j::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype) , dtype, allocateBoth); +} + +Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->primary(); +} + +Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->special(); +} + +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { + delete dataBuffer; +} + +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); +} + +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); +} + +void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocatePrimary(); +} + +void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->allocateSpecial(); +} + +void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); +} + +OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); +} + +void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToSpecial(); +} + +void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->syncToPrimary(nullptr); +} + +void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readPrimary(); +} + +void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writePrimary(); +} + +void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->readSpecial(); +} + +void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { + dataBuffer->dataBuffer()->writeSpecial(); +} + +void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { + dataBuffer->expand(elements); +} + +void dbClose(OpaqueDataBuffer *dataBuffer) { + dataBuffer->getDataBuffer()->close(); +} + +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { + return dataBuffer->deviceId(); +} + +void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { + dataBuffer->setDeviceId(deviceId); +} + +int dbLocality(OpaqueDataBuffer *dataBuffer) { + auto p = dataBuffer->dataBuffer()->isPrimaryActual(); + auto d = dataBuffer->dataBuffer()->isSpecialActual(); + + if (p && d) + return 0; + else if (p) + return -1; + else + return 1; } \ No newline at end of file diff --git a/libnd4j/include/array/ArrayOptions.h b/libnd4j/include/array/ArrayOptions.h index a753be1bf..484228fb7 100644 --- a/libnd4j/include/array/ArrayOptions.h +++ b/libnd4j/include/array/ArrayOptions.h @@ -34,10 +34,12 @@ #define ARRAY_SPARSE 2 #define ARRAY_COMPRESSED 4 #define ARRAY_EMPTY 8 +#define ARRAY_RAGGED 16 -#define ARRAY_CSR 16 -#define ARRAY_CSC 32 -#define ARRAY_COO 64 + +#define ARRAY_CSR 32 +#define ARRAY_CSC 64 +#define ARRAY_COO 128 // complex values #define ARRAY_COMPLEX 512 @@ -72,8 +74,10 @@ // boolean values #define ARRAY_BOOL 524288 -// utf-8 values -#define ARRAY_STRING 1048576 +// UTF values +#define ARRAY_UTF8 1048576 +#define ARRAY_UTF16 4194304 +#define ARRAY_UTF32 16777216 // flag for extras #define ARRAY_EXTRAS 2097152 @@ -173,8 +177,12 @@ namespace nd4j { return nd4j::DataType ::UINT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return nd4j::DataType ::UINT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return nd4j::DataType ::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return nd4j::DataType ::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return nd4j::DataType ::UTF32; else { //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ @@ -190,8 +198,12 @@ namespace nd4j { return nd4j::DataType::INT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return nd4j::DataType::INT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_STRING)) + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return nd4j::DataType::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return nd4j::DataType::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return nd4j::DataType::UTF32; else { //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ @@ -224,6 +236,8 @@ namespace nd4j { return ArrayType::COMPRESSED; else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) return ArrayType::EMPTY; + else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) + return ArrayType::RAGGED; else // by default we return DENSE type here return ArrayType::DENSE; } @@ -333,7 +347,13 @@ namespace nd4j { setPropertyBit(shapeInfo, ARRAY_LONG); break; case nd4j::DataType::UTF8: - setPropertyBit(shapeInfo, ARRAY_STRING); + setPropertyBit(shapeInfo, ARRAY_UTF8); + break; + case nd4j::DataType::UTF16: + setPropertyBit(shapeInfo, ARRAY_UTF16); + break; + case nd4j::DataType::UTF32: + setPropertyBit(shapeInfo, ARRAY_UTF32); break; default: #ifndef __CUDA_ARCH__ diff --git a/libnd4j/include/array/ArrayType.h b/libnd4j/include/array/ArrayType.h index 2300bf841..d4d6c9729 100644 --- a/libnd4j/include/array/ArrayType.h +++ b/libnd4j/include/array/ArrayType.h @@ -27,6 +27,7 @@ namespace nd4j { SPARSE = 2, COMPRESSED = 3, EMPTY = 4, + RAGGED = 5, }; } diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 034f16a25..cd27c20b8 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -36,13 +36,14 @@ class ND4J_EXPORT DataBuffer { private: - void* _primaryBuffer; - void* _specialBuffer; - size_t _lenInBytes; + void* _primaryBuffer = nullptr; + void* _specialBuffer = nullptr; + size_t _lenInBytes = 0; DataType _dataType; - memory::Workspace* _workspace; + memory::Workspace* _workspace = nullptr; bool _isOwnerPrimary; bool _isOwnerSpecial; + std::atomic _deviceId; #ifdef __CUDABLAS__ mutable std::atomic _counter; @@ -52,51 +53,52 @@ class ND4J_EXPORT DataBuffer { mutable std::atomic _readSpecial; #endif - void setCountersToZero(); - void copyCounters(const DataBuffer& other); - void deleteSpecial(); - FORCEINLINE void deletePrimary(); - FORCEINLINE void deleteBuffers(); - FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); - void allocateBuffers(const bool allocBoth = false); - void setSpecial(void* special, const bool isOwnerSpecial); - void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); + void setCountersToZero(); + void copyCounters(const DataBuffer& other); + void deleteSpecial(); + void deletePrimary(); + void deleteBuffers(); + void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); + void allocateBuffers(const bool allocBoth = false); + void setSpecial(void* special, const bool isOwnerSpecial); + void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); public: - FORCEINLINE DataBuffer(void* primary, void* special, + DataBuffer(void* primary, void* special, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary = false, const bool isOwnerSpecial = false, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(void* primary, + DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary = false, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer + DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace = nullptr); - FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); + DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); - FORCEINLINE DataBuffer(const DataBuffer& other); - FORCEINLINE DataBuffer(DataBuffer&& other); - FORCEINLINE explicit DataBuffer(); - FORCEINLINE ~DataBuffer(); + DataBuffer(const DataBuffer& other); + DataBuffer(DataBuffer&& other); + explicit DataBuffer(); + ~DataBuffer(); - FORCEINLINE DataBuffer& operator=(const DataBuffer& other); - FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept; + DataBuffer& operator=(const DataBuffer& other); + DataBuffer& operator=(DataBuffer&& other) noexcept; - FORCEINLINE DataType getDataType(); - FORCEINLINE size_t getLenInBytes() const; + DataType getDataType(); + void setDataType(DataType dataType); + size_t getLenInBytes() const; - FORCEINLINE void* primary(); - FORCEINLINE void* special(); + void* primary(); + void* special(); - FORCEINLINE void allocatePrimary(); - void allocateSpecial(); + void allocatePrimary(); + void allocateSpecial(); void writePrimary() const; void writeSpecial() const; @@ -105,6 +107,10 @@ class ND4J_EXPORT DataBuffer { bool isPrimaryActual() const; bool isSpecialActual() const; + void expand(const uint64_t size); + + int deviceId() const; + void setDeviceId(int deviceId); void migrate(); template FORCEINLINE T* primaryAsT(); @@ -118,256 +124,28 @@ class ND4J_EXPORT DataBuffer { void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); static void memcpy(const DataBuffer &dst, const DataBuffer &src); + + void setPrimaryBuffer(void *buffer, size_t length); + void setSpecialBuffer(void *buffer, size_t length); + + /** + * This method deletes buffers, if we're owners + */ + void close(); }; - - - ///// IMLEMENTATION OF INLINE METHODS ///// - //////////////////////////////////////////////////////////////////////// -// default constructor -DataBuffer::DataBuffer() { - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = 0; - _dataType = INT8; - _workspace = nullptr; - _isOwnerPrimary = false; - _isOwnerSpecial = false; - - setCountersToZero(); -} - -//////////////////////////////////////////////////////////////////////// -// copy constructor -DataBuffer::DataBuffer(const DataBuffer &other) { - - throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); - - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - - setCountersToZero(); - - allocateBuffers(); - copyBufferFrom(other); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary, const bool isOwnerSpecial, - memory::Workspace* workspace) { - - if (primary == nullptr && special == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); - - _primaryBuffer = primary; - _specialBuffer = special; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; - - setCountersToZero(); - - if(primary != nullptr) - readPrimary(); - if(special != nullptr) - readSpecial(); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): - DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { - - syncToSpecial(true); -} - -//////////////////////////////////////////////////////////////////////// -// copies data from hostBuffer to own memory buffer -DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { - - if (hostBuffer == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); - if (lenInBytes == 0) - throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - - setCountersToZero(); - - allocateBuffers(); - - copyBufferFromHost(hostBuffer, lenInBytes); -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { - - _dataType = dataType; - _workspace = workspace; - _lenInBytes = lenInBytes; - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - - setCountersToZero(); - - if(lenInBytes != 0) { - allocateBuffers(allocBoth); - writeSpecial(); + template + T* DataBuffer::primaryAsT() { + return reinterpret_cast(_primaryBuffer); } -} //////////////////////////////////////////////////////////////////////// -// move constructor -DataBuffer::DataBuffer(DataBuffer&& other) { - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; -} - -//////////////////////////////////////////////////////////////////////// -// assignment operator -DataBuffer& DataBuffer::operator=(const DataBuffer& other) { - - if (this == &other) - return *this; - - deleteBuffers(); - - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - - allocateBuffers(); - copyBufferFrom(other); - - return *this; -} - -//////////////////////////////////////////////////////////////////////// -// move assignment operator -DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { - - if (this == &other) - return *this; - - deleteBuffers(); - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; - - return *this; -} - -//////////////////////////////////////////////////////////////////////// -void* DataBuffer::primary() { - return _primaryBuffer; -} - -//////////////////////////////////////////////////////////////////////// -void* DataBuffer::special() { - return _specialBuffer; -} - -//////////////////////////////////////////////////////////////////////// -DataType DataBuffer::getDataType() { - return _dataType; -} - -//////////////////////////////////////////////////////////////////////// -size_t DataBuffer::getLenInBytes() const { - return _lenInBytes; -} - -//////////////////////////////////////////////////////////////////////// -template -T* DataBuffer::primaryAsT() { - return reinterpret_cast(_primaryBuffer); -} - -//////////////////////////////////////////////////////////////////////// -template -T* DataBuffer::specialAsT() { - return reinterpret_cast(_specialBuffer); -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::allocatePrimary() { - - if (_primaryBuffer == nullptr && getLenInBytes() > 0) { - ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerPrimary = true; + template + T* DataBuffer::specialAsT() { + return reinterpret_cast(_specialBuffer); } -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { - - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::deletePrimary() { - - if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_primaryBuffer); - RELEASE(p, _workspace); - _primaryBuffer = nullptr; - _isOwnerPrimary = false; - } -} - -//////////////////////////////////////////////////////////////////////// -void DataBuffer::deleteBuffers() { - - deletePrimary(); - deleteSpecial(); - _lenInBytes = 0; -} - -//////////////////////////////////////////////////////////////////////// -DataBuffer::~DataBuffer() { - - deleteBuffers(); -} - } diff --git a/libnd4j/include/array/DataType.h b/libnd4j/include/array/DataType.h index b3e21840d..8ec55342e 100644 --- a/libnd4j/include/array/DataType.h +++ b/libnd4j/include/array/DataType.h @@ -42,6 +42,8 @@ namespace nd4j { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200, }; diff --git a/libnd4j/include/array/InteropDataBuffer.h b/libnd4j/include/array/InteropDataBuffer.h new file mode 100644 index 000000000..3cbfc2f94 --- /dev/null +++ b/libnd4j/include/array/InteropDataBuffer.h @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +#ifndef LIBND4J_INTEROPDATABUFFER_H +#define LIBND4J_INTEROPDATABUFFER_H + +namespace nd4j { + /** + * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages + */ + class ND4J_EXPORT InteropDataBuffer { + private: + std::shared_ptr _dataBuffer; + uint64_t _offset = 0; + public: + InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset); + InteropDataBuffer(std::shared_ptr databuffer); + InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth); + ~InteropDataBuffer() = default; + +#ifndef __JAVACPP_HACK__ + std::shared_ptr getDataBuffer() const; + std::shared_ptr dataBuffer(); +#endif + + void* primary() const; + void* special() const; + + uint64_t offset() const ; + void setOffset(uint64_t offset); + + void setPrimary(void* ptr, size_t length); + void setSpecial(void* ptr, size_t length); + + void expand(size_t newlength); + + int deviceId() const; + void setDeviceId(int deviceId); + + static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); + static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + + static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); + static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + }; +} + + +#endif //LIBND4J_INTEROPDATABUFFER_H diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index d13ca0def..ccd782adc 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -23,6 +23,24 @@ #include namespace nd4j { + void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t *newBuffer = nullptr; + ALLOCATE(newBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + RELEASE(reinterpret_cast(_primaryBuffer), _workspace); + } + + _primaryBuffer = newBuffer; + _lenInBytes = size; + _isOwnerPrimary = true; + } + } //////////////////////////////////////////////////////////////////////// void DataBuffer::setCountersToZero() { @@ -99,14 +117,17 @@ void DataBuffer::allocateSpecial() { void DataBuffer::migrate() { } -/////////////////////////////////////////////////////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes < dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); - std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes); +///////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); + + std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); + dst.readPrimary(); } + //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const { } void DataBuffer::writeSpecial() const { } diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index 5cb227e69..2a3efa3c8 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -25,6 +25,40 @@ #include namespace nd4j { + void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t *newBuffer = nullptr; + int8_t *newSpecialBuffer = nullptr; + ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + if (_primaryBuffer != nullptr) { + // there's non-zero chance that primary buffer doesn't exist yet + ALLOCATE(newBuffer, _workspace, size, int8_t); + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + auto ipb = reinterpret_cast(_primaryBuffer); + RELEASE(ipb, _workspace); + } + + _primaryBuffer = newBuffer; + _isOwnerPrimary = true; + } + + cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice); + + if (_isOwnerSpecial) { + auto isb = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(isb, _workspace); + } + + _specialBuffer = newSpecialBuffer; + _lenInBytes = size; + _isOwnerSpecial = true; + } + } //////////////////////////////////////////////////////////////////////// void DataBuffer::allocateSpecial() { @@ -37,8 +71,9 @@ void DataBuffer::allocateSpecial() { //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - if(isPrimaryActual() && !forceSync) + if(isPrimaryActual() && !forceSync) { return; + } allocatePrimary(); @@ -46,7 +81,9 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn if (res != 0) throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res); - cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); + res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res); readPrimary(); } @@ -54,13 +91,19 @@ void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSyn //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToSpecial(const bool forceSync) { - - if(isSpecialActual() && !forceSync) + // in this case there's nothing to do here + if (_primaryBuffer == nullptr) return; + if(isSpecialActual() && !forceSync) { + return; + } + allocateSpecial(); - cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); + auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res); readSpecial(); } @@ -97,19 +140,6 @@ void DataBuffer::copyCounters(const DataBuffer& other) { _readPrimary.store(other._writeSpecial); _readSpecial.store(other._writePrimary); } -//////////////////////////////////////////////////////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes < dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); - - if (src.isSpecialActual()) { - cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice); - } else if (src.isPrimaryActual()) { - cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice); - } - - dst.writeSpecial(); -} //////////////////////////////////////////////////////////////////////// void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer @@ -176,8 +206,11 @@ void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate s //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { + cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream()); + auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res); - cudaMemset(special(), 0, getLenInBytes()); writeSpecial(); if(both) { @@ -186,12 +219,37 @@ void DataBuffer::setToZeroBuffers(const bool both) { } } +///////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); + + + int res = 0; + if (src.isSpecialActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream()); + } else if (src.isPrimaryActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream()); + } + + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res); + + res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); + + dst.writeSpecial(); +} + //////////////////////////////////////////////////////////////////////// void DataBuffer::migrate() { memory::Workspace* newWorkspace = nullptr; void* newBuffer; ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); - cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); + auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res); if (_isOwnerSpecial) { // now we're releasing original buffer @@ -203,7 +261,7 @@ void DataBuffer::migrate() { } //////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const { _writePrimary = ++_counter; } +void DataBuffer::writePrimary() const {_writePrimary = ++_counter; } void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp new file mode 100644 index 000000000..fae25478f --- /dev/null +++ b/libnd4j/include/array/impl/DataBuffer.cpp @@ -0,0 +1,301 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include + +namespace nd4j { + ///// IMLEMENTATION OF COMMON METHODS ///// + + +//////////////////////////////////////////////////////////////////////// +// default constructor + DataBuffer::DataBuffer() { + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = 0; + _dataType = INT8; + _workspace = nullptr; + _isOwnerPrimary = false; + _isOwnerSpecial = false; + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + } + +//////////////////////////////////////////////////////////////////////// +// copy constructor + DataBuffer::DataBuffer(const DataBuffer &other) { + + throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); + + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + + _deviceId.store(other._deviceId.load()); + + setCountersToZero(); + + allocateBuffers(); + copyBufferFrom(other); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(void* primary, void* special, + const size_t lenInBytes, const DataType dataType, + const bool isOwnerPrimary, const bool isOwnerSpecial, + memory::Workspace* workspace) { + + if (primary == nullptr && special == nullptr) + throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); + + _primaryBuffer = primary; + _specialBuffer = special; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if(primary != nullptr) + readPrimary(); + if(special != nullptr) + readSpecial(); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): + DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { + + syncToSpecial(true); + } + +//////////////////////////////////////////////////////////////////////// +// copies data from hostBuffer to own memory buffer + DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { + + if (hostBuffer == nullptr) + throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); + if (lenInBytes == 0) + throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + allocateBuffers(); + + copyBufferFromHost(hostBuffer, lenInBytes); + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { + + _dataType = dataType; + _workspace = workspace; + _lenInBytes = lenInBytes; + + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + + _deviceId = nd4j::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if(lenInBytes != 0) { + allocateBuffers(allocBoth); + writeSpecial(); + } + } + +//////////////////////////////////////////////////////////////////////// +// move constructor + DataBuffer::DataBuffer(DataBuffer&& other) { + + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + _deviceId.store(other._deviceId); + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; + } + +//////////////////////////////////////////////////////////////////////// +// assignment operator + DataBuffer& DataBuffer::operator=(const DataBuffer& other) { + + if (this == &other) + return *this; + + deleteBuffers(); + + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + + allocateBuffers(); + copyBufferFrom(other); + + return *this; + } + +//////////////////////////////////////////////////////////////////////// +// move assignment operator + DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { + + if (this == &other) + return *this; + + deleteBuffers(); + + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; + + return *this; + } + +//////////////////////////////////////////////////////////////////////// + void* DataBuffer::primary() { + return _primaryBuffer; + } + +//////////////////////////////////////////////////////////////////////// + void* DataBuffer::special() { + return _specialBuffer; + } + +//////////////////////////////////////////////////////////////////////// + DataType DataBuffer::getDataType() { + return _dataType; + } + +//////////////////////////////////////////////////////////////////////// + size_t DataBuffer::getLenInBytes() const { + return _lenInBytes; + } + + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::allocatePrimary() { + + if (_primaryBuffer == nullptr && getLenInBytes() > 0) { + ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerPrimary = true; + } + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { + + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::deletePrimary() { + + if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_primaryBuffer); + RELEASE(p, _workspace); + _primaryBuffer = nullptr; + _isOwnerPrimary = false; + } + } + +//////////////////////////////////////////////////////////////////////// + void DataBuffer::deleteBuffers() { + + deletePrimary(); + deleteSpecial(); + _lenInBytes = 0; + } + +//////////////////////////////////////////////////////////////////////// + DataBuffer::~DataBuffer() { + + deleteBuffers(); + } + + void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) { + if (_primaryBuffer != nullptr && _isOwnerPrimary) { + deletePrimary(); + } + _primaryBuffer = buffer; + _isOwnerPrimary = false; + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); + } + + void DataBuffer::setSpecialBuffer(void *buffer, size_t length) { + this->setSpecial(buffer, false); + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); + } + + void DataBuffer::setDataType(DataType dataType) { + _dataType = dataType; + } + + int DataBuffer::deviceId() const { + return _deviceId.load(); + } + + void DataBuffer::close() { + this->deleteBuffers(); + } + + void DataBuffer::setDeviceId(int deviceId) { + _deviceId = deviceId; + } +} diff --git a/libnd4j/include/array/impl/InteropDataBuffer.cpp b/libnd4j/include/array/impl/InteropDataBuffer.cpp new file mode 100644 index 000000000..cffc1462b --- /dev/null +++ b/libnd4j/include/array/impl/InteropDataBuffer.cpp @@ -0,0 +1,146 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +namespace nd4j { + InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) { + _dataBuffer = dataBuffer.getDataBuffer(); + + // offset is always absolute to the original buffer + _offset = offset; + + if (_offset + length > _dataBuffer->getLenInBytes()) { + throw std::runtime_error("offset + length is higher than original length"); + } + } + + InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { + _dataBuffer = databuffer; + } + + InteropDataBuffer::InteropDataBuffer(size_t elements, nd4j::DataType dtype, bool allocateBoth) { + if (elements == 0) { + _dataBuffer = std::make_shared(); + _dataBuffer->setDataType(dtype); + } else { + _dataBuffer = std::make_shared(elements, dtype, nullptr, allocateBoth); + } + } + + std::shared_ptr InteropDataBuffer::getDataBuffer() const { + return _dataBuffer; + } + + std::shared_ptr InteropDataBuffer::dataBuffer() { + return _dataBuffer; + } + + void* InteropDataBuffer::primary() const { + return reinterpret_cast(_dataBuffer->primary()) + _offset; + } + + void* InteropDataBuffer::special() const { + return reinterpret_cast(_dataBuffer->special()) + _offset; + } + + void InteropDataBuffer::setPrimary(void* ptr, size_t length) { + _dataBuffer->setPrimaryBuffer(ptr, length); + } + + void InteropDataBuffer::setSpecial(void* ptr, size_t length) { + _dataBuffer->setSpecialBuffer(ptr, length); + } + + uint64_t InteropDataBuffer::offset() const { + return _offset; + } + + void InteropDataBuffer::setOffset(uint64_t offset) { + _offset = offset; + } + + int InteropDataBuffer::deviceId() const { + return _dataBuffer->deviceId(); + } + + + void InteropDataBuffer::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { + for (const auto &v:writeList) { + if (v == nullptr) + continue; + + v->getDataBuffer()->writeSpecial(); + } + } + + void InteropDataBuffer::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { + auto currentDeviceId = nd4j::AffinityManager::currentDeviceId(); + for (const auto &v:readList) { + if (v == nullptr) + continue; + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + + v->getDataBuffer()->syncToSpecial(); + } + + // we don't tick write list, only ensure the same device affinity + for (const auto &v:writeList) { + if (v == nullptr) + continue; + + // special case for legacy ops - views can be updated on host side, thus original array can be not updated + if (!v->getDataBuffer()->isSpecialActual()) + v->getDataBuffer()->syncToSpecial(); + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + } + } + + void InteropDataBuffer::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { + for (const auto &v:writeList) { + if (v == nullptr) + continue; + } + } + + void InteropDataBuffer::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { + for (const auto &v:readList) { + if (v == nullptr) + continue; + + v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); + } + } + + void InteropDataBuffer::expand(size_t newlength) { + _dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType())); + } + + void InteropDataBuffer::setDeviceId(int deviceId) { + _dataBuffer->setDeviceId(deviceId); + } +} diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 435858462..e018cf807 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -138,7 +138,7 @@ namespace nd4j { if (res != 0) throw cuda_exception::build("_reductionPointer allocation failed", res); - res = cudaMalloc(reinterpret_cast(&_scalarPointer), 16); + res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, cudaHostAllocDefault); if (res != 0) throw cuda_exception::build("_scalarPointer allocation failed", res); diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index dc36e0704..57988da79 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -185,9 +185,11 @@ namespace nd4j { void setInputArray(int index, NDArray *array, bool removable = false); void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); + void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setOutputArray(int index, NDArray *array, bool removable = false); void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); + void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); void setTArguments(double *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index c8c76c7df..2725a2667 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace nd4j { @@ -426,6 +427,44 @@ namespace nd4j { array->setContext(_context); } + void Context::setInputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_in.size() < index + 1) + _fastpath_in.resize(index+1); + + NDArray *array; + if (dataBuffer != nullptr) + array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + else + array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_in[index] = array; + _handles.emplace_back(array); + + if (_context != nullptr) + array->setContext(_context); + } + + void Context::setOutputArray(int index, void *vdatabuffer, void *shapeInfo, void *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_out.size() < index + 1) + _fastpath_out.resize(index+1); + + NDArray *array; + if (dataBuffer != nullptr) + array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), nd4j::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + else + array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + _handles.emplace_back(array); + + if (_context != nullptr) + array->setContext(_context); + } + void Context::setTArguments(double *arguments, int numberOfArguments) { _tArgs.clear(); _tArgs.reserve(numberOfArguments); diff --git a/libnd4j/include/graph/scheme/array.fbs b/libnd4j/include/graph/scheme/array.fbs index 91e338500..2ffce58bd 100644 --- a/libnd4j/include/graph/scheme/array.fbs +++ b/libnd4j/include/graph/scheme/array.fbs @@ -43,6 +43,8 @@ enum DType:byte { QINT16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, } // this structure describe NDArray diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index 945bebe8e..3c3fe1d58 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -34,8 +34,6 @@ #include #include -#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");} - #endif #include namespace nd4j { diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index 1a450450f..2a562de4b 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -25,6 +25,8 @@ #include #include #include +#include +#include namespace nd4j { class ND4J_EXPORT StringUtils { @@ -53,6 +55,36 @@ namespace nd4j { return result; } + + /** + * This method returns number of needle matches within haystack + * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 + * + * @param haystack + * @param haystackLength + * @param needle + * @param needleLength + * @return + */ + static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength); + + /** + * This method returns number of bytes used for string NDArrays content + * PLEASE NOTE: this doesn't include header + * + * @param array + * @return + */ + static uint64_t byteLength(const NDArray &array); + + /** + * This method splits a string into substring by delimiter + * + * @param haystack + * @param delimiter + * @return + */ + static std::vector split(const std::string &haystack, const std::string &delimiter); }; } diff --git a/libnd4j/include/helpers/impl/StringUtils.cpp b/libnd4j/include/helpers/impl/StringUtils.cpp index cd0383a75..faace2c63 100644 --- a/libnd4j/include/helpers/impl/StringUtils.cpp +++ b/libnd4j/include/helpers/impl/StringUtils.cpp @@ -19,7 +19,58 @@ // #include +#include namespace nd4j { + static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) { + for (int e = 0; e < length; e++) + if (haystack[e] != needle[e]) + return false; + return true; + } + + uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) { + auto haystack = reinterpret_cast(vhaystack); + auto needle = reinterpret_cast(vneedle); + + uint64_t number = 0; + + for (uint64_t e = 0; e < haystackLength - needleLength; e++) { + if (match(&haystack[e], needle, needleLength)) + number++; + } + + return number; + } + + + uint64_t StringUtils::byteLength(const NDArray &array) { + if (!array.isS()) + throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); + + uint64_t result = 0; + + // our buffer stores offsets, and the last value is basically number of bytes used + auto buffer = array.bufferAsT(); + result = buffer[array.lengthOf()]; + + return result; + } + + std::vector StringUtils::split(const std::string &haystack, const std::string &delimiter) { + std::vector output; + + std::string::size_type prev_pos = 0, pos = 0; + + // iterating through the haystack till the end + while((pos = haystack.find(delimiter, pos)) != std::string::npos) { + output.emplace_back(haystack.substr(prev_pos, pos-prev_pos)); + prev_pos = ++pos; + } + + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word + + return output; + } } diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index 882b1740e..1ee820853 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp index 8286d209c..d0a80a3f5 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index 76dc209f6..e53c9ac8e 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index cbd7e6e12..929d9c4ff 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -20,7 +20,6 @@ // #include -#include #include #include #include diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index 52685a2c9..97f33569b 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1624,4 +1624,9 @@ #define PARAMETRIC_D() [&] (Parameters &p) -> Context* + +#ifdef __CUDABLAS__ +#define checkCudaErrors(ERR) if (ERR != 0) {throw std::runtime_error("CUDA stream synchronization failed");} +#endif + #endif diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 7d699c49b..0b0e42809 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -40,6 +40,9 @@ #include #include #include +#include +#include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/compat/README.md b/libnd4j/include/ops/declarable/generic/compat/README.md new file mode 100644 index 000000000..ff44ae4c1 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/README.md @@ -0,0 +1 @@ +This folder contains operations required for compatibility with TF and other frameworks. \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp new file mode 100644 index 000000000..4a84dbdac --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_split_string) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + NDArray *def = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + if (block.width() > 3) + def = INPUT_VARIABLE(3); + + nd4j::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(compat_sparse_to_dense) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + + if (block.width() > 3) { + auto def = INPUT_VARIABLE(3); + + REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values") + }; + + auto dtype = values->dataType(); + + // basically output shape is defined by the type of input, and desired shape input + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector())); + } + + DECLARE_TYPES(compat_sparse_to_dense) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) // indices + ->setAllowedInputTypes(1, {ALL_INTS}) // shape + ->setAllowedInputTypes(2,nd4j::DataType::ANY) // sparse values + ->setAllowedInputTypes(3,nd4j::DataType::ANY) // default value + ->setAllowedOutputTypes(nd4j::DataType::ANY); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp new file mode 100644 index 000000000..9d7b57ee4 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_split_string) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto indices = OUTPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(1); + + auto d = delim->e(0); + + input->syncToHost(); + delim->syncToHost(); + + // output rank N+1 wrt input rank + std::vector ocoords(input->rankOf() + 1); + std::vector icoords(input->rankOf()); + + // getting buffer lengths + // FIXME: it'll be bigger, since it'll include delimiters, + auto outputLength = StringUtils::byteLength(*input); + + uint64_t ss = 0L; + Nd4jLong ic = 0L; + // loop through each string within tensor + for (auto e = 0L; e < input->lengthOf(); e++) { + // now we should map substring to indices + auto s = input->e(e); + + // getting base index + shape::index2coords(e, input->shapeInfo(), icoords.data()); + + // getting number of substrings + auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; + + // filling output indices + for (uint64_t f = 0; f < cnt; f++) { + for (auto v: icoords) + indices->p(ic++, v); + + // last index + indices->p(ic++, f); + } + + ss += cnt; + } + + // process strings now + std::vector strings; + for (auto e = 0L; e < input->lengthOf(); e++) { + auto split = StringUtils::split(input->e(e), d); + + for (const auto &s:split) + strings.emplace_back(s); + } + + // now once we have all strings in single vector time to fill + auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings); + auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); + + // for CUDA mostly + values->dataBuffer()->allocatePrimary(); + values->dataBuffer()->expand(blen); + memcpy(values->buffer(), tmp.buffer(), blen); + values->tickWriteHost(); + + // special case, for future use + indices->syncToDevice(); + values->syncToDevice(); + + // we have to tick buffers + values->dataBuffer()->writePrimary(); + values->dataBuffer()->readSpecial(); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(compat_string_split) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto d = delim->e(0); + + // count number of delimiter substrings in all strings within input tensor + uint64_t cnt = 0; + for (auto e = 0L; e < input->lengthOf(); e++) { + // FIXME: bad, not UTF-compatible + auto s = input->e(e); + + // each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays + cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; + } + + // shape calculations + // virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension + // values tensor is going to be vector always + // indices tensor is going to be vector with length equal to values.length * output rank + + auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8); + auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64); + + return SHAPELIST(indicesShape, valuesShape); + } + + DECLARE_TYPES(compat_string_split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_STRINGS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 24f96f7a7..8591d3449 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -47,8 +47,7 @@ namespace nd4j { } // just memcpy data -// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant - DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/strings/split_string.cpp b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp new file mode 100644 index 000000000..4af4e3aac --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_split_string) + +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return Status::OK(); + }; + + DECLARE_SHAPE_FN(split_string) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return SHAPELIST(); + } + + DECLARE_TYPES(split_string) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes({ALL_STRINGS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp new file mode 100644 index 000000000..6b1514ab9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_print_affinity) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + nd4j_printf(": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->getBuffer(), input->getSpecialBuffer(), input->dataBuffer()->getLenInBytes()); + + return Status::OK(); + } + + DECLARE_TYPES(print_affinity) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, nd4j::DataType::INT32); + } + + DECLARE_SHAPE_FN(print_affinity) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp new file mode 100644 index 000000000..6828b2f90 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#include +#if NOT_EXCLUDED(OP_print_variable) + +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::string str; + + if (block.width() == 2) { + auto message = INPUT_VARIABLE(1); + REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String"); + + str = message->e(0); + } + + bool printSpecial = false; + if (block.numB() > 0) + printSpecial = B_ARG(0); + + if (printSpecial && !nd4j::Environment::getInstance()->isCPU()) { + // only specific backends support special printout. for cpu-based backends it's the same as regular print + + if (block.width() == 2) + helpers::print_special(*block.launchContext(), *input, str); + else + helpers::print_special(*block.launchContext(), *input); + } else { + // optionally add message to the print out + if (block.width() == 2) { + input->printIndexedBuffer(str.c_str()); + } else { + input->printIndexedBuffer(); + } + } + + return Status::OK(); + } + + DECLARE_TYPES(print_variable) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, nd4j::DataType::INT32); + } + + DECLARE_SHAPE_FN(print_variable) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/compat.h b/libnd4j/include/ops/declarable/headers/compat.h new file mode 100644 index 000000000..8ce73153e --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/compat.h @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_COMPAT_H +#define SAMEDIFF_COMPAT_H + +#include + +namespace nd4j { + namespace ops { + /** + * This operation splits input string into pieces separated by delimiter + * PLEASE NOTE: This implementation is compatible with TF 1.x + * + * Input[0] - string to split + * Input[1] - delimiter + * + * Returns: + * Output[0] - indices tensor + * Output[1] - values tensor + */ + #if NOT_EXCLUDED(OP_compat_string_split) + DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); + #endif + + /** + * This operation converts TF sparse array representation to dense NDArray + */ + #if NOT_EXCLUDED(OP_compat_sparse_to_dense) + DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); + #endif + + } +} + + +#endif //SAMEDIFF_COMPAT_H diff --git a/libnd4j/include/ops/declarable/headers/strings.h b/libnd4j/include/ops/declarable/headers/strings.h new file mode 100644 index 000000000..0849f118a --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/strings.h @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_STRINGS_H +#define SAMEDIFF_STRINGS_H + +#include + +namespace nd4j { + namespace ops { + /** + * This operation splits input string into pieces separated by delimiter + * + * Input[0] - string to split + * Input[1] - delimiter + */ + #if NOT_EXCLUDED(OP_split_string) + DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); + #endif + + } +} + + +#endif //SAMEDIFF_STRINGS_H diff --git a/libnd4j/include/ops/declarable/headers/util.h b/libnd4j/include/ops/declarable/headers/util.h new file mode 100644 index 000000000..aa1f52363 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/util.h @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_UTILS_H +#define LIBND4J_UTILS_H + +#include + +namespace nd4j { + namespace ops { + /** + * This operation prints out NDArray content, either on host or device. + */ + #if NOT_EXCLUDED(OP_print_variable) + DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); + #endif + + /** + * This operation prints out affinity & locality status of given NDArray + */ + #if NOT_EXCLUDED(OP_print_affinity) + DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); + #endif + } +} + +#endif //LIBND4J_UTILS_H diff --git a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp new file mode 100644 index 000000000..293518be6 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { + array.printIndexedBuffer(message.c_str()); + } + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index b2a13bfce..7bddb00fe 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -40,15 +40,11 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xzLen, totalThreads, *sharedMem; + __shared__ Nd4jLong xzLen; __shared__ int xzRank, yRank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xzLen = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; xzRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); @@ -56,18 +52,15 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong* coords = sharedMem + threadIdx.x * xzRank; - - for (int i = tid; i < xzLen; i += totalThreads) { + Nd4jLong coords[MAX_RANK]; + for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { shape::index2coords(i, xShapeInfo, coords); const auto xzOffset = shape::getOffset(xShapeInfo, coords); - const auto xVal = x[xzOffset]; if(xVal < 0) { - for (uint j = 0; j < yRank; ++j) if(yShapeInfo[j + 1] == 1) coords[j + 1] = 0; @@ -82,7 +75,6 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, /////////////////////////////////////////////////////////////////// template linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { - preluCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz); } @@ -91,9 +83,9 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a PointersManager manager(context, "prelu"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const auto xType = input.dataType(); const auto yType = alpha.dataType(); @@ -119,13 +111,10 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI auto dLdI = reinterpret_cast(vdLdI); auto dLdA = reinterpret_cast(vdLdA); - __shared__ Nd4jLong inLen, totalThreads, *sharedMem; + __shared__ Nd4jLong inLen, totalThreads; __shared__ int inRank, alphaRank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - inLen = shape::length(inShapeInfo); totalThreads = gridDim.x * blockDim.x; @@ -135,10 +124,9 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong* coords = sharedMem + threadIdx.x * inRank; + Nd4jLong coords[MAX_RANK]; for (int i = tid; i < inLen; i += totalThreads) { - shape::index2coords(i, inShapeInfo, coords); const auto inOffset = shape::getOffset(inShapeInfo, coords); @@ -175,14 +163,13 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr ////////////////////////////////////////////////////////////////////////// void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - - dLdA.nullify(); + dLdA.nullify(); PointersManager manager(context, "preluBP"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const auto xType = input.dataType(); const auto zType = alpha.dataType(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu new file mode 100644 index 000000000..88d2b5937 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) { + auto length = shape::length(shapeInfo); + auto x = reinterpret_cast(special); + + // TODO: add formatting here + printf("["); + + for (uint64_t e = 0; e < length; e++) { + printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]); + + if (e < length - 1) + printf(", "); + } + + printf("]\n"); + } + + template + static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) { + print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); + } + + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { + NDArray::prepareSpecialUse({}, {&array}); + + PointersManager pm(&ctx, "print_device"); + BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.getSpecialBuffer(), array.getSpecialShapeInfo()), LIBND4J_TYPES) + pm.synchronize(); + + NDArray::registerSpecialUse({}, {&array}); + } + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/helpers.h b/libnd4j/include/ops/declarable/helpers/helpers.h index f2e19063e..f3aebc7b7 100644 --- a/libnd4j/include/ops/declarable/helpers/helpers.h +++ b/libnd4j/include/ops/declarable/helpers/helpers.h @@ -41,6 +41,9 @@ #include #include #include + +#include + #endif // CUDACC #endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp new file mode 100644 index 000000000..e21499314 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp @@ -0,0 +1,123 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { + auto values = reinterpret_cast(vvalues); + auto indices = reinterpret_cast(vindices); + auto output = reinterpret_cast(voutput); + + Nd4jLong coords[MAX_RANK]; + uint64_t pos = 0; + for (uint64_t e = 0L; e < length; e++) { + // indices come in blocks + for (uint8_t p = 0; p < rank; p++) { + coords[p] = indices[pos++]; + } + + // fill output at given coords with sparse value + output[shape::getOffset(zShapeInfo, coords)] = values[e]; + } + + } + + void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) { + // make sure host buffer is updated + values.syncToHost(); + indices.syncToHost(); + + auto rank = output.rankOf(); + + if (output.isS()) { + // string case is not so trivial, since elements might, and probably will, have different sizes + auto numValues = values.lengthOf(); + auto numElements = output.lengthOf(); + + // first of all we calculate final buffer sizes and offsets + auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); + auto valuesLength = StringUtils::byteLength(values); + auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength; + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); + + // now we make sure our output buffer can hold results + output.dataBuffer()->expand( bufferLength + headerLength); + + std::vector outputCoords(rank); + std::vector valueCoords(rank); + + auto offsetsBuffer = output.bufferAsT(); + auto dataBuffer = reinterpret_cast(offsetsBuffer + output.lengthOf()); + + offsetsBuffer[0] = 0; + + // getting initial value coords + for (int e = 0; e < rank; e++) + valueCoords[e] = indices.e(e); + + // write results individually + for (uint64_t e = 0; e < numElements; e++) { + auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); + auto cLength = 0L; + std::string str; + if (vIndex == e) { + // we're writing down sparse value here + str = values.e(e); + } else { + // we're writing down default value if it exists + if (def != nullptr) + str = def->e(0); + else + str = ""; + } + + // TODO: make it unicode compliant + memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); + + // writing down offset + offsetsBuffer[e+1] = cLength; + } + } else { + // numeric case is trivial, since all elements have equal sizes + + // write out default values, if they are present + if (def != nullptr) { + output.assign(def); + + // make sure output is synced back + output.syncToHost(); + } + + // write out values + BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.getBuffer(), indices.getBuffer(), output.buffer(), output.getShapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES); + } + // copy back to device, if there's any + output.syncToDevice(); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/print_variable.h b/libnd4j/include/ops/declarable/helpers/print_variable.h new file mode 100644 index 000000000..3521e38b9 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/print_variable.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_PRINT_VARIABLE_H +#define LIBND4J_PRINT_VARIABLE_H + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {}); + } + } +} + +#endif //LIBND4J_PRINT_VARIABLE_H diff --git a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h new file mode 100644 index 000000000..8d00639de --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_SPARSE_TO_DENSE_H +#define SAMEDIFF_SPARSE_TO_DENSE_H + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output); + } + } +} + +#endif //SAMEDIFF_SPARSE_TO_DENSE_H diff --git a/libnd4j/include/type_boilerplate.h b/libnd4j/include/type_boilerplate.h index bd235726a..af0fe369d 100644 --- a/libnd4j/include/type_boilerplate.h +++ b/libnd4j/include/type_boilerplate.h @@ -634,7 +634,7 @@ #define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) - +#define ALL_STRINGS nd4j::DataType::UTF8, nd4j::DataType::UTF16, nd4j::DataType::UTF32 #define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64 #define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64 #define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16 diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 4cf10ed00..15aa5751c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -810,9 +810,10 @@ TEST_F(DeclarableOpsTests12, pullRows_1) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif - - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); @@ -844,8 +845,10 @@ TEST_F(DeclarableOpsTests12, pullRows_2) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.specialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp new file mode 100644 index 000000000..543043ebd --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -0,0 +1,94 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests17 : public testing::Test { +public: + + DeclarableOpsTests17() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { + auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::create(0.f); + auto exp = NDArrayFactory::create('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f}); + + + nd4j::ops::compat_sparse_to_dense op; + auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { + auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::string("d"); + auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); + + + nd4j::ops::compat_sparse_to_dense op; + auto result = op.execute({&ranges, &shape, &values, &def}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + +TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { + auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); + auto delimiter = NDArrayFactory::string(" "); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + + nd4j::ops::compat_string_split op; + auto result = op.execute({&x, &delimiter}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_EQ(2, result->size()); + + auto z0 = result->at(0); + auto z1 = result->at(1); + + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + + ASSERT_EQ(exp0, *z0); + ASSERT_EQ(exp1, *z1); + + delete result; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp new file mode 100644 index 000000000..93864af8c --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests18 : public testing::Test { +public: + + DeclarableOpsTests18() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests18, test_bitcast_1) { + auto x = NDArrayFactory::create(0.23028551377579154); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(4597464930322771456L); + + nd4j::ops::bitcast op; + auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong) nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp new file mode 100644 index 000000000..871bfe186 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests19 : public testing::Test { +public: + + DeclarableOpsTests19() { + printf("\n"); + fflush(stdout); + } +}; \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 0e8db97ff..f058d9112 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -834,12 +834,17 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dims.dataBuffer()); - execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); @@ -981,10 +986,14 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) { NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + OpaqueDataBuffer xBuf(arrayX.dataBuffer()); + OpaqueDataBuffer yBuf(arrayY.dataBuffer()); + OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); + execPairwiseTransform(nullptr, pairwise::Add, - arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), - arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(), - arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), + &xBuf, arrayX.shapeInfo(), arrayX.getSpecialShapeInfo(), + &yBuf, arrayY.shapeInfo(), arrayY.getSpecialShapeInfo(), + &zBuf, arrayZ.shapeInfo(), arrayZ.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); @@ -1220,10 +1229,10 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) { auto z = NDArrayFactory::create('c', {10}); RandomGenerator rng(119, 323841120L); bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; - execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args); + OpaqueDataBuffer zBuf(z.dataBuffer()); + execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args); //z.printIndexedBuffer("z"); - ASSERT_TRUE(z.sumNumber().e(0) > 0); } @@ -1267,6 +1276,64 @@ TEST_F(JavaInteropTests, test_size_dtype_1) { ASSERT_EQ(e, z); } +TEST_F(JavaInteropTests, test_expandable_array_op_1) { + auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" "); + + auto z0 = NDArrayFactory::create('c', {6}); + auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""}); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + + InteropDataBuffer iz0(z0.dataBuffer()); + InteropDataBuffer iz1(z1.dataBuffer()); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); + ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); + ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); + + nd4j::ops::compat_string_split op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); +} + +TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { + if (!Environment::getInstance()->isCPU()) + return; + + auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); + auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); + + double buffer[2048]; + + InteropDataBuffer ix(0, DataType::DOUBLE, false); + InteropDataBuffer iy(0, DataType::DOUBLE, false); + InteropDataBuffer iz(0, DataType::DOUBLE, false); + + // we're imitating workspace-managed array here + ix.setPrimary(buffer + 64, x.lengthOf()); + iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); + iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); + + Context ctx(1); + ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); + ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); + ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); + + ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + + nd4j::ops::maxpool2d_bp op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index cb4d4d07d..f0b7628ee 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -470,12 +470,16 @@ TEST_F(LegacyOpsTests, Reduce3_2) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineSimilarity, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -506,14 +510,17 @@ TEST_F(LegacyOpsTests, Reduce3_3) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); ASSERT_EQ(e, z); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -543,14 +550,17 @@ TEST_F(LegacyOpsTests, Reduce3_4) { auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1}); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); // z.printIndexedBuffer("z"); @@ -583,13 +593,16 @@ TEST_F(LegacyOpsTests, Reduce3_5) { NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); execReduce3Tad(extraPointers, reduce3::CosineDistance, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); @@ -615,10 +628,15 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) { NDArray::prepareSpecialUse({&z}, {&x, &y}); - execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); @@ -730,13 +748,16 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) { auto z = NDArrayFactory::create('c', {0, 2}); auto e = NDArrayFactory::create('c', {0, 2}); + InteropDataBuffer xdb(x.dataBuffer()); + InteropDataBuffer ddb(d.dataBuffer()); + InteropDataBuffer zdb(z.dataBuffer()); ::execReduceSame2(nullptr, reduce::SameOps::Sum, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + &xdb, x.shapeInfo(), x.specialShapeInfo(), nullptr, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); + &zdb, z.shapeInfo(), z.specialShapeInfo(), + &ddb, d.shapeInfo(), d.specialShapeInfo()); } diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index e426eeb1f..42eb50be0 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -119,13 +119,15 @@ TEST_F(NativeOpsTests, ExecIndexReduce_1) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execIndexReduceScalar(nullptr, indexreduce::IndexMax, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + nullptr, + &expBuf, exp.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 4LL); #endif @@ -140,15 +142,18 @@ TEST_F(NativeOpsTests, ExecIndexReduce_2) { printf("Unsupported for cuda now.\n"); #else NDArray dimension = NDArrayFactory::create({}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); + ::execIndexReduce(nullptr, indexreduce::IndexMax, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), + nullptr, + &dimensionBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 24LL); #endif @@ -166,16 +171,21 @@ TEST_F(NativeOpsTests, ExecBroadcast_1) { #else auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcast(nullptr, broadcast::Add, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &xBuf, x.shapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), + nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(0) == 3.); #endif @@ -194,17 +204,18 @@ printf("Unsupported for cuda now.\n"); int dimd = 0; auto dimension = NDArrayFactory::create('c', {1}, {dimd}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcastBool(nullptr, broadcast::EqualTo, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - nullptr, - dimension.buffer(), dimension.shapeInfo(), - nullptr, nullptr); + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); ASSERT_TRUE(exp.e(1) && !exp.e(0)); #endif @@ -219,14 +230,15 @@ TEST_F(NativeOpsTests, ExecPairwise_1) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execPairwiseTransform(nullptr, pairwise::Add, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr); ASSERT_TRUE(exp.e(5) == 8.); #endif @@ -243,14 +255,15 @@ TEST_F(NativeOpsTests, ExecPairwise_2) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execPairwiseTransformBool(nullptr, pairwise::And, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, - y.buffer(), y.shapeInfo(), - nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr); ASSERT_TRUE(exp.e(5) && !exp.e(4)); #endif @@ -266,14 +279,14 @@ TEST_F(NativeOpsTests, ReduceTest_1) { printf("Unsupported for cuda now.\n"); #else auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceFloat(nullptr, reduce::Mean, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Mean"); ASSERT_TRUE(exp.e(0) == 13.); @@ -289,14 +302,14 @@ TEST_F(NativeOpsTests, ReduceTest_2) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceSame(nullptr, reduce::Sum, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Sum"); ASSERT_TRUE(exp.e(0) == 325.); @@ -312,14 +325,14 @@ TEST_F(NativeOpsTests, ReduceTest_3) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceBool(nullptr, reduce::All, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.e(0) == true); @@ -335,14 +348,14 @@ TEST_F(NativeOpsTests, ReduceTest_4) { #ifdef __CUDABLAS__ printf("Unsupported for cuda now.\n"); #else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execReduceLong(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr); + &expBuf, exp.shapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.e(0) == 25LL); @@ -359,15 +372,16 @@ TEST_F(NativeOpsTests, ReduceTest_5) { printf("Unsupported for cuda now.\n"); #else auto dimension = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduceLong2(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.e(0) == 25LL); @@ -389,15 +403,17 @@ TEST_F(NativeOpsTests, ReduceTest_6) { x.p(10, 0); x.p(11, 0); x.p(15, 0); x.p(16, 0); x.p(17, 0); x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceLong2(nullptr, reduce::CountNonZero, - x.buffer(), x.shapeInfo(), - nullptr, nullptr, + &xBuf, x.shapeInfo(), nullptr, nullptr, - exp.buffer(), exp.shapeInfo(), - nullptr, nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce CountNonZero"); ASSERT_TRUE(exp.equalsTo(z)); @@ -421,15 +437,16 @@ TEST_F(NativeOpsTests, ReduceTest_7) { x.linspace(1.0); x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceFloat2(extra, reduce::Mean, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Mean"); ASSERT_TRUE(exp.equalsTo(z)); @@ -453,16 +470,16 @@ TEST_F(NativeOpsTests, ReduceTest_8) { x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); ::execReduceSame2(extra, reduce::Sum, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce Sum"); ASSERT_TRUE(exp.equalsTo(z)); @@ -485,15 +502,17 @@ TEST_F(NativeOpsTests, ReduceTest_9) { x.syncToDevice(); dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduceBool2(extra, reduce::All, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo()); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.equalsTo(z)); @@ -518,15 +537,16 @@ TEST_F(NativeOpsTests, Reduce3Test_1) { y.assign(2.); x.syncToDevice(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduce3(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); //z.printIndexedBuffer("Z"); //exp.printIndexedBuffer("Reduce3 Dot"); ASSERT_TRUE(exp.equalsTo(z)); @@ -551,15 +571,16 @@ TEST_F(NativeOpsTests, Reduce3Test_2) { y.assign(2.); x.syncToDevice(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execReduce3Scalar(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo()); + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce3 Dot"); ASSERT_TRUE(exp.equalsTo(z)); @@ -585,17 +606,18 @@ TEST_F(NativeOpsTests, Reduce3Test_3) { x.syncToDevice(); dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduce3Tad(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -630,17 +652,18 @@ TEST_F(NativeOpsTests, Reduce3Test_4) { auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); auto hTADOffsetsY = tadPackY.primaryOffsets(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execReduce3All(extra, reduce3::Dot, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -667,14 +690,16 @@ TEST_F(NativeOpsTests, ScalarTest_1) { //y.assign(2.); x.syncToDevice(); z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execScalar(extra, scalar::Multiply, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), nullptr); + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.equalsTo(z)); @@ -700,14 +725,16 @@ TEST_F(NativeOpsTests, ScalarTest_2) { //y.assign(2.); x.syncToDevice(); z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execScalarBool(extra, scalar::GreaterThan, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), nullptr); + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15) != z.e(15)); @@ -726,13 +753,14 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { printf("Unsupported for CUDA platform yet.\n"); return; #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execSummaryStatsScalar(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), false); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Standard Variance"); ASSERT_TRUE(exp.equalsTo(z)); @@ -751,13 +779,13 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) { printf("Unsupported for CUDA platform yet.\n"); return; #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); ::execSummaryStats(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), false); + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Standard Variance"); ASSERT_TRUE(exp.equalsTo(z)); @@ -777,15 +805,16 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { return; #endif auto dimensions = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); + ::execSummaryStatsTad(extra, variance::SummaryStatsVariance, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - dimensions.buffer(), dimensions.shapeInfo(), - dimensions.specialBuffer(), dimensions.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(), false, nullptr, nullptr); // x.printIndexedBuffer("Input"); @@ -807,13 +836,15 @@ TEST_F(NativeOpsTests, TransformTest_1) { return; #endif z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformFloat(extra, transform::Sqrt, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Sqrt is"); @@ -834,13 +865,15 @@ TEST_F(NativeOpsTests, TransformTest_2) { return; #endif z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformSame(extra, transform::Square, - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Square is"); @@ -864,13 +897,14 @@ TEST_F(NativeOpsTests, TransformTest_3) { z.assign(true); x.p(24, -25); z.p(24, false); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformBool(extra, transform::IsPositive, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("IsPositive"); @@ -894,13 +928,13 @@ TEST_F(NativeOpsTests, TransformTest_4) { return; #endif //z.linspace(1.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execTransformStrict(extra, transform::Cosine, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), nullptr); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Cosine"); @@ -932,17 +966,18 @@ TEST_F(NativeOpsTests, ScalarTadTest_1) { auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execScalarTad(extra, scalar::Multiply, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("Reduce All"); @@ -977,17 +1012,21 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) { auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); z.assign(true); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execScalarBoolTad(extra, scalar::And, - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo(), - exp.buffer(), exp.shapeInfo(), - exp.specialBuffer(), exp.specialShapeInfo(), - y.buffer(), y.shapeInfo(), - y.specialBuffer(), y.specialShapeInfo(), + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, - dimension.buffer(), dimension.shapeInfo(), - dimension.specialBuffer(), dimension.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); // x.printIndexedBuffer("Input"); // exp.printIndexedBuffer("And"); @@ -1095,9 +1134,11 @@ TEST_F(NativeOpsTests, PullRowsTest_1) { #ifdef __CUDABLAS__ nativeStart[1] = (x.getContext()->getCudaStream()); #endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), - z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + pullRows(nativeStart, &xBuf, x.getShapeInfo(), x.getSpecialShapeInfo(), + &zBuf, z.getShapeInfo(), z.specialShapeInfo(), 4, pidx, xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); @@ -1250,7 +1291,9 @@ TEST_F(NativeOpsTests, RandomTest_1) { #endif graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom(extra, random::BernoulliDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_2) { @@ -1264,7 +1307,10 @@ TEST_F(NativeOpsTests, RandomTest_2) { x.linspace(0, 0.01); graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom2(extra, random::DropOut, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_3) { @@ -1280,7 +1326,12 @@ TEST_F(NativeOpsTests, RandomTest_3) { x.linspace(1, -0.01); graph::RandomGenerator rng(1023, 119); double p = 0.5; - ::execRandom3(extra, random::ProbablisticMerge, &rng, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &p); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf, + y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_4) { @@ -1316,6 +1367,10 @@ TEST_F(NativeOpsTests, SortTests_2) { #ifdef __CUDABLAS__ extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif +// OpaqueDataBuffer xBuf(x.dataBuffer()); +// OpaqueDataBuffer yBuf(y.dataBuffer()); +// OpaqueDataBuffer expBuf(exp.dataBuffer()); +// OpaqueDataBuffer dimBuf(exp.dataBuffer()); ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); k.tickWriteDevice(); @@ -1541,6 +1596,13 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { ::deleteShapeList((Nd4jPointer) shapeList); } + +TEST_F(NativeOpsTests, interop_databuffer_tests_1) { + auto idb = ::allocateDataBuffer(100, 10, false); + auto ptr = ::dbPrimaryBuffer(idb); + ::deleteDataBuffer(idb); +} + //Uncomment when needed only - massive calculations //TEST_F(NativeOpsTests, BenchmarkTests_1) { // diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index 9f9569b92..ec7821f21 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -90,4 +90,26 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(f, z1); delete dup; +} + +TEST_F(StringTests, byte_length_test_1) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); +} + +TEST_F(StringTests, byte_length_test_2) { + auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"}); + + ASSERT_EQ(9, StringUtils::byteLength(array)); +} + +TEST_F(StringTests, test_split_1) { + auto split = StringUtils::split("alpha beta gamma", " "); + + ASSERT_EQ(3, split.size()); + ASSERT_EQ(std::string("alpha"), split[0]); + ASSERT_EQ(std::string("beta"), split[1]); + ASSERT_EQ(std::string("gamma"), split[2]); } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java index a9862d253..d0501454c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -1,5 +1,6 @@ package org.nd4j.autodiff.listeners.debugging; +import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; @@ -113,16 +114,16 @@ public class ExecDebuggingListener extends BaseListener { if(co.tArgs() != null && co.tArgs().length > 0) { sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs())); } - INDArray[] inputs = co.inputArguments(); - INDArray[] outputs = co.outputArguments(); + val inputs = co.inputArguments(); + val outputs = co.outputArguments(); if(inputs != null ) { - for (int i = 0; i < inputs.length; i++) { - sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString()); + for (int i = 0; i < inputs.size(); i++) { + sb.append("\n\tInput[").append(i).append("]=").append(inputs.get(i).shapeInfoToString()); } } if(outputs != null ) { - for (int i = 0; i < outputs.length; i++) { - sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString()); + for (int i = 0; i < outputs.size(); i++) { + sb.append("\n\tOutputs[").append(i).append("]=").append(outputs.get(i).shapeInfoToString()); } } } else { @@ -156,22 +157,22 @@ public class ExecDebuggingListener extends BaseListener { if(co.tArgs() != null && co.tArgs().length > 0 ){ sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n"); } - INDArray[] inputs = co.inputArguments(); - INDArray[] outputs = co.outputArguments(); + val inputs = co.inputArguments(); + val outputs = co.outputArguments(); if(inputs != null ) { - sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n"); - for (int i = 0; i < inputs.length; i++) { + sb.append("INDArray[] inputs = new INDArray[").append(inputs.size()).append("];\n"); + for (int i = 0; i < inputs.size(); i++) { sb.append("inputs[").append(i).append("] = "); - sb.append(createString(inputs[i])) + sb.append(createString(inputs.get(i))) .append(";\n"); } sb.append("op.addInputArgument(inputs);\n"); } if(outputs != null ) { - sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n"); - for (int i = 0; i < outputs.length; i++) { + sb.append("INDArray[] outputs = new INDArray[").append(outputs.size()).append("];\n"); + for (int i = 0; i < outputs.size(); i++) { sb.append("outputs[").append(i).append("] = "); - sb.append(createString(outputs[i])) + sb.append(createString(outputs.get(i))) .append(";\n"); } sb.append("op.addOutputArgument(outputs);\n"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 7640d450c..bd35650dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -478,11 +478,11 @@ public class InferenceSession extends AbstractSession { } throw new IllegalStateException(s); } - return ((Assert) op).outputArguments(); + return ((Assert) op).outputArguments().toArray(new INDArray[0]); } else if (op instanceof CustomOp) { CustomOp c = (CustomOp) op; Nd4j.exec(c); - return c.outputArguments(); + return c.outputArguments().toArray(new INDArray[0]); } else if (op instanceof Op) { Op o = (Op) op; Nd4j.exec(o); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index fc7572180..d57ab7c97 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -457,7 +457,7 @@ public class OpValidation { for (int i = 0; i < testCase.testFns().size(); i++) { String error; try { - error = testCase.testFns().get(i).apply(testCase.op().outputArguments()[i]); + error = testCase.testFns().get(i).apply(testCase.op().outputArguments().get(i)); } catch (Throwable t) { throw new IllegalStateException("Exception thrown during op output validation for output " + i, t); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index 7e7a50ab2..9eee099a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -1,6 +1,7 @@ package org.nd4j.autodiff.validation.listeners; import lombok.Getter; +import lombok.val; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.Operation; @@ -50,12 +51,12 @@ public class NonInplaceValidationListener extends BaseListener { opInputs = new INDArray[]{o.x().dup(), o.y().dup()}; } } else if(op.getOp() instanceof DynamicCustomOp){ - INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments(); - opInputs = new INDArray[arr.length]; - opInputsOrig = new INDArray[arr.length]; - for( int i=0; i(in, dLdalpha); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 77b946559..ab622b34f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -23,7 +23,6 @@ import com.google.flatbuffers.FlatBufferBuilder; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; -import net.ericaro.neoitertools.Generator; import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.BytePointer; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; @@ -998,14 +997,14 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - Pair tadInfo = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); + Pair tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); DataBuffer shapeInfo = tadInfo.getFirst(); - val shape = Shape.shape(shapeInfo); - val stride = Shape.stride(shapeInfo).asLong(); + val jShapeInfo = shapeInfo.asLong(); + val shape = Shape.shape(jShapeInfo); + val stride = Shape.stride(jShapeInfo); long offset = offset() + tadInfo.getSecond().getLong(index); - val ews = shapeInfo.getLong(shapeInfo.getLong(0) * 2 + 2); - char tadOrder = (char) shapeInfo.getInt(shapeInfo.getLong(0) * 2 + 3); + val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); + char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); return toTad; } @@ -2217,9 +2216,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { if(isEmpty() || isS()) return false; - return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0 - || (length() < data().length() && data.dataType() != DataType.INT) - || data().originalDataBuffer() != null; + val c2 = (length() < data().length() && data.dataType() != DataType.INT); + val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); + + return c2 || c3; } @Override @@ -3585,6 +3585,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { case DOUBLE: case FLOAT: case HALF: + case BFLOAT16: return getDouble(i); case LONG: case INT: @@ -3592,6 +3593,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { case UBYTE: case BYTE: case BOOL: + case UINT64: + case UINT32: + case UINT16: return getLong(i); case UTF8: case COMPRESSED: @@ -4350,29 +4354,30 @@ public abstract class BaseNDArray implements INDArray, Iterable { //epsilon equals if (isScalar() && n.isScalar()) { - if (data.dataType() == DataType.FLOAT) { - double val = getDouble(0); - double val2 = n.getDouble(0); + if (isZ()) { + val val = getLong(0); + val val2 = n.getLong(0); + + return val == val2; + } else if (isR()) { + val val = getDouble(0); + val val2 = n.getDouble(0); if (Double.isNaN(val) != Double.isNaN(val2)) return false; return Math.abs(val - val2) < eps; - } else { - double val = getDouble(0); - double val2 = n.getDouble(0); + } else if (isB()) { + val val = getInt(0); + val val2 = n.getInt(0); - if (Double.isNaN(val) != Double.isNaN(val2)) - return false; - - return Math.abs(val - val2) < eps; + return val == val2; } } else if (isVector() && n.isVector()) { - - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); + val op = new EqualsWithEps(this, n, eps); + Nd4j.exec(op); + val diff = op.z().getDouble(0); return diff < 0.5; } @@ -4750,8 +4755,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; checkArrangeArray(rearrange); - int[] newShape = doPermuteSwap(shapeOf(), rearrange); - int[] newStride = doPermuteSwap(strideOf(), rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); @@ -4777,23 +4782,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; checkArrangeArray(rearrange); - val newShape = doPermuteSwap(Shape.shapeOf(shapeInfo), rearrange); - val newStride = doPermuteSwap(Shape.stride(shapeInfo), rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); char newOrder = Shape.getOrder(newShape, newStride, 1); - //Set the shape information of this array: shape, stride, order. - //Shape info buffer: [rank, [shape], [stride], offset, elementwiseStride, order] - /*for( int i=0; i outputArguments(); - - - INDArray[] outputArguments(); - - INDArray[] inputArguments(); + List inputArguments(); long[] iArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 99e930176..e46dfab4b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -261,19 +261,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } @Override - public INDArray[] outputArguments() { - if (!outputArguments.isEmpty()) { - return outputArguments.toArray(new INDArray[0]); - } - return new INDArray[0]; + public List outputArguments() { + return outputArguments; } @Override - public INDArray[] inputArguments() { - if (!inputArguments.isEmpty()) - return inputArguments.toArray(new INDArray[0]); - return new INDArray[0]; - + public List inputArguments() { + return inputArguments; } @Override @@ -367,10 +361,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { for (int i = 0; i < args.length; i++) { // it's possible to get into situation where number of args > number of arrays AT THIS MOMENT - if (i >= arrsSoFar.length) + if (i >= arrsSoFar.size()) continue; - if (!Arrays.equals(args[i].getShape(), arrsSoFar[i].shape())) + if (!Arrays.equals(args[i].getShape(), arrsSoFar.get(i).shape())) throw new ND4JIllegalStateException("Illegal array passed in as argument [" + i + "]. Expected shape " + Arrays.toString(args[i].getShape()) + " and received array with shape " + Arrays.toString(arg[i].shape())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java new file mode 100644 index 000000000..18293c2ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatSparseToDense.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.compat; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * This is a wrapper for SparseToDense op that impelements corresponding TF operation + * + * @author raver119@gmail.com + */ +public class CompatSparseToDense extends DynamicCustomOp { + + public CompatSparseToDense() { + // + } + + public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values) { + Preconditions.checkArgument(shape.isZ() && indices.isZ(), "Shape & indices arrays must have one integer data types"); + inputArguments.add(indices); + inputArguments.add(shape); + inputArguments.add(values); + } + + public CompatSparseToDense(INDArray indices, INDArray shape, INDArray values, INDArray defaultVaule) { + this(indices, shape, values); + Preconditions.checkArgument(defaultVaule.dataType() == values.dataType(), "Values array must have the same data type as defaultValue array"); + inputArguments.add(defaultVaule); + } + + @Override + public String opName() { + return "compat_sparse_to_dense"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java new file mode 100644 index 000000000..33b6df4a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/compat/CompatStringSplit.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.compat; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * This is a wrapper for StringSplit op that impelements corresponding TF operation + * + * @author raver119@gmail.com + */ +public class CompatStringSplit extends DynamicCustomOp { + + public CompatStringSplit() { + // + } + + public CompatStringSplit(INDArray strings, INDArray delimiter) { + Preconditions.checkArgument(strings.isS() && delimiter.isS(), "Input arrays must have one of UTF types"); + inputArguments.add(strings); + inputArguments.add(delimiter); + } + + public CompatStringSplit(INDArray strings, INDArray delimiter, INDArray indices, INDArray values) { + this(strings, delimiter); + + outputArguments.add(indices); + outputArguments.add(values); + } + + @Override + public String opName() { + return "compat_string_split"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index cb805a775..83020cb57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -107,12 +107,12 @@ public class ScatterUpdate implements CustomOp { } @Override - public INDArray[] outputArguments() { + public List outputArguments() { return op.outputArguments(); } @Override - public INDArray[] inputArguments() { + public List inputArguments() { return op.inputArguments(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index 57606e452..aea251ebd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -23,7 +23,6 @@ import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -172,7 +171,7 @@ public class DefaultOpExecutioner implements OpExecutioner { @Override public INDArray[] exec(CustomOp op) { - return execAndReturn(op).outputArguments(); + return execAndReturn(op).outputArguments().toArray(new INDArray[0]); } @Override @@ -822,7 +821,7 @@ public class DefaultOpExecutioner implements OpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { throw new UnsupportedOperationException(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java index 1be417644..c4af57864 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/OpExecutioner.java @@ -20,7 +20,6 @@ import lombok.NonNull; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArrayStatistics; import org.nd4j.linalg.api.ops.*; @@ -32,8 +31,6 @@ import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.TadPack; import org.nd4j.linalg.cache.TADManager; -import org.nd4j.linalg.primitives.Pair; -import org.nd4j.linalg.profiler.OpProfiler; import org.nd4j.linalg.profiler.ProfilerConfig; import java.util.List; @@ -411,7 +408,7 @@ public interface OpExecutioner { * @param index * @return */ - String getString(Utf8Buffer buffer, long index); + String getString(DataBuffer buffer, long index); /** * Temporary hook diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java new file mode 100644 index 000000000..d21e55916 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintAffinity.java @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.util; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +/** + * This is a wrapper for PrintAffinity op that just prints out affinity & locality status of INDArray + * + * @author raver119@gmail.com + */ +public class PrintAffinity extends DynamicCustomOp { + + public PrintAffinity() { + // + } + + public PrintAffinity(INDArray array) { + inputArguments.add(array); + } + + @Override + public String opName() { + return "print_affinity"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java new file mode 100644 index 000000000..abbf88f15 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/util/PrintVariable.java @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.util; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +/** + * This is a wrapper for PrintVariable op that just prints out Variable to the stdout + * + * @author raver119@gmail.com + */ +public class PrintVariable extends DynamicCustomOp { + + public PrintVariable() { + // + } + + public PrintVariable(INDArray array, boolean printSpecial) { + inputArguments.add(array); + bArguments.add(printSpecial); + } + + public PrintVariable(INDArray array) { + this(array, false); + } + + public PrintVariable(INDArray array, String message, boolean printSpecial) { + this(array, Nd4j.create(message), printSpecial); + } + + public PrintVariable(INDArray array, String message) { + this(array, Nd4j.create(message), false); + } + + public PrintVariable(INDArray array, INDArray message, boolean printSpecial) { + this(array, printSpecial); + Preconditions.checkArgument(message.isS(), "Message argument should have String data type, but got [" + message.dataType() +"] instead"); + inputArguments.add(message); + } + + public PrintVariable(INDArray array, INDArray message) { + this(array, message, false); + } + + @Override + public String opName() { + return "print_variable"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index c6e5bf904..107a68dd4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -89,6 +89,11 @@ public class CompressedDataBuffer extends BaseDataBuffer { // no-op } + @Override + public Pointer addressPointer() { + return pointer; + } + /** * Drop-in replacement wrapper for BaseDataBuffer.read() method, aware of CompressedDataBuffer * @param s @@ -194,6 +199,15 @@ public class CompressedDataBuffer extends BaseDataBuffer { */ @Override public DataBuffer create(int[] data) { - throw new UnsupportedOperationException("This operation isn't supported for CompressedDataBuffer"); + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); + } + + public void pointerIndexerByCurrentType(DataType currentType) { + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); + } + + @Override + public DataBuffer reallocate(long length) { + throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index 9c0645156..ae26633e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -98,7 +98,7 @@ public class Convolution { .build(); Nd4j.getExecutioner().execAndReturn(col2Im); - return col2Im.outputArguments()[0]; + return col2Im.outputArguments().get(0); } public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW, @@ -187,7 +187,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } public static INDArray im2col(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dH, int dW, boolean isSameMode, @@ -208,7 +208,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } /** @@ -298,7 +298,7 @@ public class Convolution { .build()).build(); Nd4j.getExecutioner().execAndReturn(im2col); - return im2col.outputArguments()[0]; + return im2col.outputArguments().get(0); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 94b17142b..dae946dba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -40,7 +40,6 @@ import org.nd4j.graph.FlatArray; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; -import org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.BasicAffinityManager; @@ -1044,16 +1043,7 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length, long offset) { - switch (type) { - case INT: - return DATA_BUFFER_FACTORY_INSTANCE.createInt(offset, buffer, length); - case DOUBLE: - return DATA_BUFFER_FACTORY_INSTANCE.createDouble(offset, buffer, length); - case FLOAT: - return DATA_BUFFER_FACTORY_INSTANCE.createFloat(offset, buffer, length); - default: - throw new IllegalArgumentException("Illegal opType " + type); - } + return DATA_BUFFER_FACTORY_INSTANCE.create(buffer, type, length, offset); } /** @@ -1336,38 +1326,9 @@ public class Nd4j { * @return the created buffer */ public static DataBuffer createBuffer(ByteBuffer buffer, DataType type, int length) { - switch (type) { - case INT: - return DATA_BUFFER_FACTORY_INSTANCE.createInt(buffer, length); - case LONG: - return DATA_BUFFER_FACTORY_INSTANCE.createLong(buffer, length); - case DOUBLE: - return DATA_BUFFER_FACTORY_INSTANCE.createDouble(buffer, length); - case FLOAT: - return DATA_BUFFER_FACTORY_INSTANCE.createFloat(buffer, length); - case HALF: - return DATA_BUFFER_FACTORY_INSTANCE.createHalf(buffer, length); - default: - throw new IllegalArgumentException("Illegal opType " + type); - } + return createBuffer(buffer, type, length, 0); } - /** - * Create a buffer based on the data opType - * - * @param data the data to create the buffer with - * @return the created buffer - */ - public static DataBuffer createBuffer(byte[] data, int length) { - DataBuffer ret; - if (dataType() == DataType.DOUBLE) - ret = DATA_BUFFER_FACTORY_INSTANCE.createDouble(data, length); - else if (dataType() == DataType.HALF) - ret = DATA_BUFFER_FACTORY_INSTANCE.createHalf(data, length); - else - ret = DATA_BUFFER_FACTORY_INSTANCE.createFloat(data, length); - return ret; - } /** * Create a buffer equal of length prod(shape) @@ -2206,6 +2167,7 @@ public class Nd4j { private static String writeStringForArray(INDArray write) { if(write.isView() || !Shape.hasDefaultStridesForShape(write)) write = write.dup(); + String format = "0.000000000000000000E0"; return "{\n" + @@ -3927,16 +3889,6 @@ public class Nd4j { return create(shape, stride); } - /** - * Creates an ndarray with the specified shape - * - * @param rows the rows of the ndarray - * @param columns the columns of the ndarray - * @return the instance - */ - public static INDArray create(int rows, int columns) { - return create(rows, columns, order()); - } /** * Creates an ndarray with the specified shape @@ -4386,13 +4338,6 @@ public class Nd4j { return createUninitialized(shape, Nd4j.order()); } - /** - * See {@link #createUninitialized(long)} - */ - public static INDArray createUninitialized(int length) { - return createUninitialized((long)length); - } - /** * This method creates an *uninitialized* ndarray of specified length and default ordering. * @@ -4428,37 +4373,6 @@ public class Nd4j { ////////////////////// OTHER /////////////////////////////// - /** - * Creates a 2D array with specified number of rows, columns initialized with zero. - * - * @param rows number of rows. - * @param columns number of columns. - * @return the created array. - */ - public static INDArray zeros(long rows, long columns) { - return INSTANCE.zeros(rows, columns); - } - - /** - * Creates a 1D array with the specified number of columns initialized with zero. - * - * @param columns number of columns. - * @return the created array - */ - public static INDArray zeros(int columns) { - return INSTANCE.zeros(columns); - } - - /** - * Creates a 1D array with the specified data tyoe and number of columns initialized with zero. - * - * @param dataType data type. - * @param columns number of columns. - * @return the created array. - */ - public static INDArray zeros(DataType dataType, int columns) { - return INSTANCE.create(dataType, new long[]{columns}, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); - } /** * Creates an array with the specified data tyoe and shape initialized with zero. @@ -4468,7 +4382,10 @@ public class Nd4j { * @return the created array. */ public static INDArray zeros(DataType dataType, @NonNull long... shape) { - return INSTANCE.create(dataType, shape, 'c', Nd4j.getMemoryManager().getCurrentWorkspace()); + if(shape.length == 0) + return Nd4j.scalar(dataType, 0); + + return INSTANCE.create(dataType, shape, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -4588,31 +4505,6 @@ public class Nd4j { return INSTANCE.valueArrayOf(rows, columns, value); } - /** - * Creates a row vector with the specified number of columns - * - * @param rows the number of rows in the matrix - * @param columns the columns of the ndarray - * @return the created ndarray - */ - public static INDArray ones(int rows, int columns) { - return INSTANCE.ones(rows, columns); - } - - /** - * Create a 2D array with the given rows, columns and data type initialised with ones. - * - * @param dataType data type - * @param rows rows of the new array. - * @param columns columns of the new arrau. - * @return the created array - */ - public static INDArray ones(DataType dataType, int rows, int columns) { - INDArray ret = INSTANCE.createUninitialized(dataType, new long[]{rows, columns}, Nd4j.order(), Nd4j.getMemoryManager().getCurrentWorkspace()); - ret.assign(1); - return ret; - } - /** * Empty like * @@ -4817,8 +4709,7 @@ public class Nd4j { for (int idx : indexes) { if (idx < 0 || idx >= source.shape()[source.rank() - sourceDimension - 1]) { - throw new IllegalStateException( - "Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]); + throw new IllegalStateException("Index can't be < 0 and >= " + source.shape()[source.rank() - sourceDimension - 1]); } } @@ -5186,7 +5077,7 @@ public class Nd4j { pp.toString(NDARRAY_FACTORY_CLASS)); Class convolutionInstanceClazz = (Class) Class .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); - String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); + String defaultName = pp.toString(DATA_BUFFER_OPS, "org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory"); Class dataBufferFactoryClazz = (Class) Class .forName(pp.toString(DATA_BUFFER_OPS, defaultName)); Class shapeInfoProviderClazz = (Class) Class @@ -5871,7 +5762,7 @@ public class Nd4j { arr[e] = sb.get(e + pos); } - val buffer = new Utf8Buffer(arr, prod); + val buffer = DATA_BUFFER_FACTORY_INSTANCE.createUtf8Buffer(arr, prod); return Nd4j.create(buffer, shapeOf); } catch (Exception e) { throw new RuntimeException(e); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index 30c68d578..9fae57705 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; /** * This class provides unified management for Deallocatable resources @@ -43,6 +44,8 @@ public class DeallocatorService { private Map referenceMap = new ConcurrentHashMap<>(); private List>> deviceMap = new ArrayList<>(); + private AtomicLong counter = new AtomicLong(0); + public DeallocatorService() { // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); @@ -69,6 +72,10 @@ public class DeallocatorService { } } + public long nextValue() { + return counter.incrementAndGet(); + } + /** * This method adds Deallocatable object instance to tracking system * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java index 5e966f850..c395959d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextSerializer.java @@ -17,10 +17,10 @@ package org.nd4j.serde.jackson.shaded; -import org.nd4j.linalg.api.buffer.Utf8Buffer; + +import lombok.val; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.shade.jackson.core.JsonGenerator; import org.nd4j.shade.jackson.databind.JsonSerializer; import org.nd4j.shade.jackson.databind.SerializerProvider; @@ -77,10 +77,9 @@ public class NDArrayTextSerializer extends JsonSerializer { jg.writeNumber(v); break; case UTF8: - Utf8Buffer utf8B = ((Utf8Buffer)arr.data()); - long n = utf8B.getNumWords(); + val n = arr.length(); for( int j=0; j${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version} ${dependency.platform} + junit junit diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java index f673a15d7..881d1e8b2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java @@ -19,6 +19,7 @@ package org.nd4j.jita.allocator.impl; import lombok.Getter; import lombok.NonNull; import lombok.Setter; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.garbage.GarbageBufferReference; @@ -29,9 +30,11 @@ import org.nd4j.jita.allocator.time.providers.MillisecondsProvider; import org.nd4j.jita.allocator.time.providers.OperativeProvider; import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,8 +57,8 @@ import java.util.concurrent.locks.ReentrantLock; public class AllocationPoint { private static Logger log = LoggerFactory.getLogger(AllocationPoint.class); - // thread safety is guaranteed by cudaLock - private volatile PointersPair pointerInfo; + @Getter + private OpaqueDataBuffer ptrDataBuffer; @Getter @Setter @@ -104,33 +107,27 @@ public class AllocationPoint { */ private volatile int deviceId; - public AllocationPoint() { - // + private long bytes; + + public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) { + ptrDataBuffer = opaqueDataBuffer; + this.bytes = bytes; + objectId = Nd4j.getDeallocatorService().nextValue(); } - public void acquireLock() { - //lock.lock(); - } - - public void releaseLock() { - //lock.unlock(); + public void setPointers(Pointer primary, Pointer special, long numberOfElements) { + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, primary, numberOfElements); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, special, numberOfElements); } public int getDeviceId() { - return deviceId; + return ptrDataBuffer.deviceId(); } public void setDeviceId(int deviceId) { - this.deviceId = deviceId; + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId); } - /* - We assume 1D memory chunk allocations. - */ - @Getter - @Setter - private AllocationShape shape; - private AtomicBoolean enqueued = new AtomicBoolean(false); @Getter @@ -164,7 +161,7 @@ public class AllocationPoint { } public long getNumberOfBytes() { - return shape.getNumberOfBytes(); + return bytes; } /* @@ -220,67 +217,25 @@ public class AllocationPoint { * This method returns CUDA pointer object for this allocation. * It can be either device pointer or pinned memory pointer, or null. * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock * @return */ public Pointer getDevicePointer() { - if (pointerInfo == null) { - log.info("pointerInfo is null"); - return null; - } - return pointerInfo.getDevicePointer(); + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbSpecialBuffer(ptrDataBuffer); } /** * This method returns CUDA pointer object for this allocation. * It can be either device pointer or pinned memory pointer, or null. * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock * @return */ public Pointer getHostPointer() { - if (pointerInfo == null) - return null; - - return pointerInfo.getHostPointer(); - } - - /** - * This method sets CUDA pointer for this allocation. - * It can be either device pointer, or pinned memory pointer, or null. - * - * PLEASE NOTE: Thread safety is guaranteed by reentrant read/write lock - * @param pointerInfo CUDA pointers wrapped into DevicePointerInfo - */ - public void setPointers(@NonNull PointersPair pointerInfo) { - this.pointerInfo = pointerInfo; - } - - public PointersPair getPointers() { - return this.pointerInfo; + return NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(ptrDataBuffer); } public synchronized void tickDeviceRead() { - // this.deviceTicks.incrementAndGet(); - // this.timerShort.triggerEvent(); - // this.timerLong.triggerEvent(); - //this.deviceAccessTime.set(realTimeProvider.getCurrentTime()); - this.accessDeviceRead = (timeProvider.getCurrentTime()); - } - - - /** - * Returns time, in milliseconds, when this point was accessed on host side - * - * @return - */ - public synchronized long getHostReadTime() { - return accessHostRead; - }; - - public synchronized long getHostWriteTime() { - return accessHostWrite; + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceRead(ptrDataBuffer); } /** @@ -302,7 +257,7 @@ public class AllocationPoint { } public synchronized void tickHostRead() { - accessHostRead = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostRead(ptrDataBuffer); } /** @@ -310,17 +265,14 @@ public class AllocationPoint { * */ public synchronized void tickDeviceWrite() { - // deviceAccessTime.set(realTimeProvider.getCurrentTime()); - tickDeviceRead(); - accessDeviceWrite = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickDeviceWrite(ptrDataBuffer); } /** * This method sets time when this point was changed on host */ public synchronized void tickHostWrite() { - tickHostRead(); - accessHostWrite = (timeProvider.getCurrentTime()); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbTickHostWrite(ptrDataBuffer); } /** @@ -329,10 +281,8 @@ public class AllocationPoint { * @return true, if data is actual, false otherwise */ public synchronized boolean isActualOnHostSide() { - boolean result = accessHostWrite >= accessDeviceWrite - || accessHostRead >= accessDeviceWrite; - - return result; + val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer); + return s <= 0; } /** @@ -341,9 +291,8 @@ public class AllocationPoint { * @return */ public synchronized boolean isActualOnDeviceSide() { - boolean result = accessDeviceWrite >= accessHostWrite - || accessDeviceRead >= accessHostWrite; - return result; + val s = NativeOpsHolder.getInstance().getDeviceNativeOps().dbLocality(ptrDataBuffer); + return s >= 0; } /** @@ -355,6 +304,6 @@ public class AllocationPoint { @Override public String toString() { - return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + ", shape=" + shape + '}'; + return "AllocationPoint{" + "deviceId=" + deviceId + ", objectId=" + objectId + "}"; } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index 8ec8734f7..ac35d1933 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -19,12 +19,10 @@ package org.nd4j.jita.allocator.impl; import lombok.Getter; import lombok.NonNull; import lombok.val; -import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.enums.Aggressiveness; import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.garbage.GarbageBufferReference; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.PointersPair; import org.nd4j.jita.allocator.time.Ring; @@ -37,29 +35,25 @@ import org.nd4j.jita.flow.FlowController; import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.handler.impl.CudaZeroHandler; import org.nd4j.jita.workspace.CudaWorkspace; -import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; -import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.cache.ConstantHandler; import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.nativeblas.NativeOpsHolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; -import java.lang.ref.ReferenceQueue; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReentrantReadWriteLock; /** @@ -285,16 +279,10 @@ public class AtomicAllocator implements Allocator { */ @Override public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) { - if (buffer instanceof Utf8Buffer) - return null; - return memoryHandler.getDevicePointer(buffer, context); } public Pointer getPointer(DataBuffer buffer) { - if (buffer instanceof Utf8Buffer) - return null; - return memoryHandler.getDevicePointer(buffer, getDeviceContext()); } @@ -320,7 +308,7 @@ public class AtomicAllocator implements Allocator { public Pointer getPointer(INDArray array, CudaContext context) { // DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); if (array.isEmpty() || array.isS()) - return null; + throw new UnsupportedOperationException("Pew-pew"); return memoryHandler.getDevicePointer(array.data(), context); } @@ -372,20 +360,17 @@ public class AtomicAllocator implements Allocator { @Override public void synchronizeHostData(DataBuffer buffer) { // we don't want non-committed ops left behind - //Nd4j.getExecutioner().push(); + Nd4j.getExecutioner().commit(); - // we don't synchronize constant buffers, since we assume they are always valid on host side - if (buffer.isConstant() || buffer.dataType() == DataType.UTF8 || AtomicAllocator.getInstance().getAllocationPoint(buffer).getPointers().getHostPointer() == null) { - return; - } + val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); - // we actually need synchronization only in device-dependant environment. no-op otherwise - if (memoryHandler.isDeviceDependant()) { - val point = getAllocationPoint(buffer.getTrackingPoint()); - if (point == null) - throw new RuntimeException("AllocationPoint is NULL"); - memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point); - } + // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); + + val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); + + //assert oPtr.address() == cPtr.address(); + //assert buffer.address() == oPtr.address(); } @@ -446,6 +431,7 @@ public class AtomicAllocator implements Allocator { public AllocationPoint pickExternalBuffer(DataBuffer buffer) { + /** AllocationPoint point = new AllocationPoint(); Long allocId = objectsTracker.getAndIncrement(); point.setObjectId(allocId); @@ -458,6 +444,9 @@ public class AtomicAllocator implements Allocator { point.tickHostRead(); return point; + */ + + throw new UnsupportedOperationException("Pew-pew"); } /** @@ -469,69 +458,8 @@ public class AtomicAllocator implements Allocator { * @param location */ @Override - public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, - boolean initialize) { - AllocationPoint point = new AllocationPoint(); - - useTracker.set(System.currentTimeMillis()); - - // we use these longs as tracking codes for memory tracking - Long allocId = objectsTracker.getAndIncrement(); - //point.attachBuffer(buffer); - point.setObjectId(allocId); - point.setShape(requiredMemory); - /* - if (buffer instanceof CudaIntDataBuffer) { - buffer.setConstant(true); - point.setConstant(true); - } - */ - /*int numBuckets = configuration.getNumberOfGcThreads(); - int bucketId = RandomUtils.nextInt(0, numBuckets); - - GarbageBufferReference reference = - new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(bucketId), point);*/ - //point.attachReference(reference); - point.setDeviceId(-1); - - if (buffer.isAttached()) { - long reqMem = AllocationUtils.getRequiredMemory(requiredMemory); - - // workaround for init order - getMemoryHandler().getCudaContext(); - point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace(); - - val pair = new PointersPair(); - val ptrDev = workspace.alloc(reqMem, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize); - - if (ptrDev != null) { - pair.setDevicePointer(ptrDev); - point.setAllocationStatus(AllocationStatus.DEVICE); - } else { - // we allocate initial host pointer only - val ptrHost = workspace.alloc(reqMem, MemoryKind.HOST, requiredMemory.getDataType(), initialize); - pair.setHostPointer(ptrHost); - - pair.setDevicePointer(ptrHost); - point.setAllocationStatus(AllocationStatus.HOST); - } - - point.setAttached(true); - - point.setPointers(pair); - } else { - // we stay naive on PointersPair, we just don't know on this level, which pointers are set. MemoryHandler will be used for that - PointersPair pair = memoryHandler.alloc(location, point, requiredMemory, initialize); - point.setPointers(pair); - } - - allocationsMap.put(allocId, point); - //point.tickHostRead(); - point.tickDeviceWrite(); - //point.setAllocationStatus(location); - return point; + public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) { + throw new UnsupportedOperationException("Pew-pew"); } @@ -619,10 +547,11 @@ public class AtomicAllocator implements Allocator { */ if (point.getBuffer() == null) { purgeZeroObject(bucketId, object, point, false); - freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + throw new UnsupportedOperationException("Pew-pew"); - elementsDropped.incrementAndGet(); - continue; + //elementsDropped.incrementAndGet(); + //continue; } else { elementsSurvived.incrementAndGet(); } @@ -682,13 +611,14 @@ public class AtomicAllocator implements Allocator { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { // we deallocate device memory purgeDeviceObject(threadId, deviceId, object, point, false); - freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); + //freeSpace.addAndGet(AllocationUtils.getRequiredMemory(point.getShape())); // and we deallocate host memory, since object is dereferenced - purgeZeroObject(point.getBucketId(), object, point, false); + //purgeZeroObject(point.getBucketId(), object, point, false); - elementsDropped.incrementAndGet(); - continue; + //elementsDropped.incrementAndGet(); + //continue; + throw new UnsupportedOperationException("Pew-pew"); } ; } else { elementsSurvived.incrementAndGet(); @@ -1014,6 +944,31 @@ public class AtomicAllocator implements Allocator { this.memoryHandler.memcpy(dstBuffer, srcBuffer); } + @Override + public void tickHostWrite(DataBuffer buffer) { + getAllocationPoint(buffer).tickHostWrite(); + } + + @Override + public void tickHostWrite(INDArray array) { + getAllocationPoint(array.data()).tickHostWrite(); + } + + @Override + public void tickDeviceWrite(INDArray array) { + getAllocationPoint(array.data()).tickDeviceWrite(); + } + + @Override + public AllocationPoint getAllocationPoint(INDArray array) { + return getAllocationPoint(array.data()); + } + + @Override + public AllocationPoint getAllocationPoint(DataBuffer buffer) { + return ((BaseCudaDataBuffer) buffer).getAllocationPoint(); + } + /** * This method returns deviceId for current thread * All values >= 0 are considered valid device IDs, all values < 0 are considered stubs. @@ -1031,48 +986,6 @@ public class AtomicAllocator implements Allocator { return new CudaPointer(getDeviceId()); } - @Override - public void tickHostWrite(DataBuffer buffer) { - AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint()); - point.tickHostWrite(); - } - - @Override - public void tickHostWrite(INDArray array) { - DataBuffer buffer = - array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - - tickHostWrite(buffer); - } - - @Override - public void tickDeviceWrite(INDArray array) { - DataBuffer buffer = - array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint()); - - point.tickDeviceWrite(); - } - - @Override - public AllocationPoint getAllocationPoint(INDArray array) { - if (array.isEmpty()) - return null; - - DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); - return getAllocationPoint(buffer); - } - - @Override - public AllocationPoint getAllocationPoint(DataBuffer buffer) { - if (buffer instanceof CompressedDataBuffer) { - log.warn("Trying to get AllocationPoint from CompressedDataBuffer"); - throw new RuntimeException("AP CDB"); - } - - return getAllocationPoint(buffer.getTrackingPoint()); - } - @Override public void registerAction(CudaContext context, INDArray result, INDArray... operands) { memoryHandler.registerAction(context, result, operands); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java index ae1ad93cd..0f65b8f00 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java @@ -23,46 +23,21 @@ import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; @Slf4j public class CudaDeallocator implements Deallocator { - private AllocationPoint point; + private OpaqueDataBuffer opaqueDataBuffer; public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) { - this.point = buffer.getAllocationPoint(); - if (this.point == null) - throw new RuntimeException(); + opaqueDataBuffer = buffer.getOpaqueDataBuffer(); } @Override public void deallocate() { log.trace("Deallocating CUDA memory"); - // skipping any allocation that is coming from workspace - if (point.isAttached() || point.isReleased()) { - // TODO: remove allocation point as well? - if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId())) - return; - - AtomicAllocator.getInstance().getFlowController().waitTillReleased(point); - - AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent()); - AtomicAllocator.getInstance().getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent()); - - AtomicAllocator.getInstance().allocationsMap().remove(point.getObjectId()); - - return; - } - - - //log.info("Purging {} bytes...", AllocationUtils.getRequiredMemory(point.getShape())); - if (point.getAllocationStatus() == AllocationStatus.HOST) { - AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false); - } else if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - AtomicAllocator.getInstance().purgeDeviceObject(0L, point.getDeviceId(), point.getObjectId(), point, false); - - // and we deallocate host memory, since object is dereferenced - AtomicAllocator.getInstance().purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false); - } + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java index 8d78ee950..7d9bfb629 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaStream_t.java @@ -17,6 +17,7 @@ package org.nd4j.jita.allocator.pointers.cuda; import lombok.NonNull; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.linalg.exception.ND4JException; @@ -37,8 +38,9 @@ public class cudaStream_t extends CudaPointer { NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); int res = nativeOps.streamSynchronize(this); - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); + val ec = nativeOps.lastErrorCode(); + if (ec != 0) + throw new RuntimeException(nativeOps.lastErrorMessage() + "; Error code: " + ec); return res; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 5548d854a..b08248bdb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -129,7 +129,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer); - long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape()); + long requiredMemoryBytes = point.getNumberOfBytes(); val originalBytes = requiredMemoryBytes; requiredMemoryBytes += 8 - (requiredMemoryBytes % 8); @@ -147,13 +147,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { - AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), - false); + //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false); + throw new UnsupportedOperationException("Pew-pew"); } val profD = PerformanceTracker.getInstance().helperStartTransaction(); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { throw new ND4JIllegalStateException("memcpyAsync failed"); } flowController.commitTransfer(context.getSpecialStream()); @@ -176,14 +176,13 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { if (currentOffset >= MAX_CONSTANT_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { - AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), - false); + //AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false); + throw new UnsupportedOperationException("Pew-pew"); } val profD = PerformanceTracker.getInstance().helperStartTransaction(); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), - originalBytes, 1, context.getSpecialStream()) == 0) { + if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getDevicePointer(), point.getHostPointer(), originalBytes, 1, context.getSpecialStream()) == 0) { throw new ND4JIllegalStateException("memcpyAsync failed"); } flowController.commitTransfer(context.getSpecialStream()); @@ -202,8 +201,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), originalBytes, 1, - context.getSpecialStream()); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getHostPointer(), originalBytes, 1, context.getSpecialStream()); flowController.commitTransfer(context.getSpecialStream()); long cAddr = deviceAddresses.get(deviceId).address() + currentOffset; @@ -212,7 +210,10 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { // logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr); point.setAllocationStatus(AllocationStatus.CONSTANT); - point.getPointers().setDevicePointer(new CudaPointer(cAddr)); + //point.setDevicePointer(new CudaPointer(cAddr)); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + point.setConstant(true); point.tickDeviceWrite(); point.setDeviceId(deviceId); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index d81de381a..48e981491 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -32,6 +32,7 @@ import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.jita.flow.FlowController; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; @@ -70,53 +71,12 @@ public class SynchronousFlowController implements FlowController { */ @Override public void synchronizeToHost(AllocationPoint point) { - - if (!point.isActualOnHostSide()) { - val context = allocator.getDeviceContext(); - - if (!point.isConstant()) - waitTillFinished(point); - - // if this piece of memory is device-dependant, we'll also issue copyback once - if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) { - long perfD = PerformanceTracker.getInstance().helperStartTransaction(); - val bytes = AllocationUtils.getRequiredMemory(point.getShape()); - - if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0) - throw new IllegalStateException("synchronizeToHost memcpyAsync failed: " + point.getShape()); - - commitTransfer(context.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.DEVICE_TO_HOST); - } - - // updating host read timer - point.tickHostRead(); - } + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer()); } @Override public void synchronizeToDevice(@NonNull AllocationPoint point) { - if (point.isConstant()) - return; - - if (!point.isActualOnDeviceSide()) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - val context = allocator.getDeviceContext(); - - long perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), - AllocationUtils.getRequiredMemory(point.getShape()), - CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync failed: " + point.getShape()); - - commitTransfer(context.getSpecialStream()); - point.tickDeviceRead(); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - } - } + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer()); } @Override @@ -147,7 +107,6 @@ public class SynchronousFlowController implements FlowController { val pointData = allocator.getAllocationPoint(operand); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); - pointData.acquireLock(); if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) { DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() @@ -172,15 +131,12 @@ public class SynchronousFlowController implements FlowController { val cId = allocator.getDeviceId(); - if (result != null && !result.isEmpty() && !result.isS()) { + if (result != null && !result.isEmpty()) { Nd4j.getCompressor().autoDecompress(result); prepareDelayedMemory(result); val pointData = allocator.getAllocationPoint(result); val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer()); - pointData.acquireLock(); - - if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { DataBuffer buffer = result.data().originalDataBuffer() == null ? result.data() : result.data().originalDataBuffer(); @@ -206,8 +162,7 @@ public class SynchronousFlowController implements FlowController { val pointData = allocator.getAllocationPoint(operand); val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); - - pointData.acquireLock(); + Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE); if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) { DataBuffer buffer = operand.data().originalDataBuffer() == null ? operand.data() @@ -240,14 +195,12 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(result.getLastWriteEvent()); result.setLastWriteEvent(eventsProvider.getEvent()); result.getLastWriteEvent().register(context.getOldStream()); - result.releaseLock(); for (AllocationPoint operand : operands) { eventsProvider.storeEvent(operand.getLastReadEvent()); operand.setLastReadEvent(eventsProvider.getEvent()); operand.getLastReadEvent().register(context.getOldStream()); - operand.releaseLock(); } // context.syncOldStream(); } @@ -263,7 +216,6 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(pointOperand.getLastWriteEvent()); pointOperand.setLastWriteEvent(eventsProvider.getEvent()); pointOperand.getLastWriteEvent().register(context.getOldStream()); - pointOperand.releaseLock(); } } @@ -276,14 +228,12 @@ public class SynchronousFlowController implements FlowController { eventsProvider.storeEvent(point.getLastWriteEvent()); point.setLastWriteEvent(eventsProvider.getEvent()); point.getLastWriteEvent().register(context.getOldStream()); - point.releaseLock(); for (INDArray operand : operands) { if (operand == null || operand.isEmpty()) continue; val pointOperand = allocator.getAllocationPoint(operand); - pointOperand.releaseLock(); eventsProvider.storeEvent(pointOperand.getLastReadEvent()); pointOperand.setLastReadEvent(eventsProvider.getEvent()); pointOperand.getLastReadEvent().register(context.getOldStream()); @@ -295,7 +245,6 @@ public class SynchronousFlowController implements FlowController { val context = allocator.getDeviceContext(); if (result != null) { - result.acquireLock(); result.setCurrentContext(context); } @@ -303,7 +252,6 @@ public class SynchronousFlowController implements FlowController { if (operand == null) continue; - operand.acquireLock(); operand.setCurrentContext(context); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 9b8c1012c..f1cbf4958 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -16,6 +16,7 @@ package org.nd4j.jita.handler.impl; +import lombok.var; import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; @@ -44,9 +45,6 @@ import org.nd4j.jita.flow.FlowController; import org.nd4j.jita.flow.impl.GridFlowController; import org.nd4j.jita.handler.MemoryHandler; import org.nd4j.jita.memory.MemoryProvider; -import org.nd4j.jita.memory.impl.CudaCachingZeroProvider; -import org.nd4j.jita.memory.impl.CudaDirectProvider; -import org.nd4j.jita.memory.impl.CudaFullCachingProvider; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -99,9 +97,6 @@ public class CudaZeroHandler implements MemoryHandler { private final AtomicBoolean wasInitialised = new AtomicBoolean(false); - @Getter - private final MemoryProvider memoryProvider; - private final FlowController flowController; private final AllocationStatus INITIAL_LOCATION; @@ -148,20 +143,6 @@ public class CudaZeroHandler implements MemoryHandler { throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]"); } - switch (configuration.getAllocationModel()) { - case CACHE_ALL: - this.memoryProvider = new CudaFullCachingProvider(); - break; - case CACHE_HOST: - this.memoryProvider = new CudaCachingZeroProvider(); - break; - case DIRECT: - this.memoryProvider = new CudaDirectProvider(); - break; - default: - throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]"); - } - int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); for (int i = 0; i < numDevices; i++) { deviceAllocations.add(new ConcurrentHashMap()); @@ -191,7 +172,7 @@ public class CudaZeroHandler implements MemoryHandler { int numBuckets = configuration.getNumberOfGcThreads(); long bucketId = RandomUtils.nextInt(0, numBuckets); - long reqMemory = AllocationUtils.getRequiredMemory(point.getShape()); + long reqMemory = point.getNumberOfBytes(); zeroUseCounter.addAndGet(reqMemory); @@ -221,130 +202,7 @@ public class CudaZeroHandler implements MemoryHandler { public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape, boolean initialize) { - long reqMemory = AllocationUtils.getRequiredMemory(shape); - val context = getCudaContext(); - switch (targetMode) { - case HOST: { - if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - - while (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - - val before = MemoryTracker.getInstance().getActiveHostAmount(); - memoryProvider.purgeCache(); - Nd4j.getMemoryManager().invokeGc(); - val after = MemoryTracker.getInstance().getActiveHostAmount(); - - log.debug("[HOST] before: {}; after: {};", before, after); - - if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { - try { - log.warn("No available [HOST] memory, sleeping for a while... Consider increasing -Xmx next time."); - log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: [" + zeroAllocations.get(0) + "]"); - - memoryProvider.purgeCache(); - Nd4j.getMemoryManager().invokeGc(); - Thread.sleep(1000); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } - } - - PointersPair pair = memoryProvider.malloc(shape, point, targetMode); - - if (initialize) { - org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory); - point.tickHostWrite(); - } - - - pickupHostAllocation(point); - - return pair; - } - case DEVICE: { - int deviceId = getDeviceId(); - - PointersPair returnPair = new PointersPair(); - PointersPair tmpPair = new PointersPair(); - - if (point.getPointers() == null) - point.setPointers(tmpPair); - - if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId, reqMemory)) { - point.setDeviceId(deviceId); - val pair = memoryProvider.malloc(shape, point, targetMode); - if (pair != null) { - returnPair.setDevicePointer(pair.getDevicePointer()); - - point.setAllocationStatus(AllocationStatus.DEVICE); - - if (point.getPointers() == null) - throw new RuntimeException("PointersPair can't be null"); - - point.getPointers().setDevicePointer(pair.getDevicePointer()); - - deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId()); - - - val p = point.getBucketId(); - - if (p != null) { - val m = zeroAllocations.get(point.getBucketId()); - - // m can be null, if that's point from workspace - just no bucketId for it - if (m != null) - m.remove(point.getObjectId()); - } - - deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory); - - if (!initialize) { - point.tickDeviceWrite(); - } else { - nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0, context.getSpecialStream()); - context.getSpecialStream().synchronize(); - - point.tickDeviceWrite(); - } - } else { - log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]; Approximate free bytes: {}; Real free bytes: {}", deviceId, reqMemory, MemoryTracker.getInstance().getApproximateFreeMemory(deviceId), MemoryTracker.getInstance().getPreciseFreeMemory(deviceId)); - log.info("Total allocated dev_0: {}", MemoryTracker.getInstance().getActiveMemory(0)); - log.info("Cached dev_0: {}", MemoryTracker.getInstance().getCachedAmount(0)); - log.info("Allocated dev_0: {}", MemoryTracker.getInstance().getAllocatedAmount(0)); - log.info("Workspace dev_0: {}", MemoryTracker.getInstance().getWorkspaceAllocatedAmount(0)); - //log.info("Total allocated dev_1: {}", MemoryTracker.getInstance().getActiveMemory(1)); - // if device memory allocation failed (aka returned NULL), keep using host memory instead - - returnPair.setDevicePointer(tmpPair.getHostPointer()); - - point.setAllocationStatus(AllocationStatus.HOST); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(100); - } catch (Exception e) { - - } - } - } else { - log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]", - deviceId); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - // - } - } - - return returnPair; - } - default: - throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]"); - } + throw new UnsupportedOperationException(); } /** @@ -356,7 +214,7 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) { - return memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory); + return true; } /** @@ -371,47 +229,7 @@ public class CudaZeroHandler implements MemoryHandler { @Override public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point, AllocationShape shape, CudaContext context) { - //log.info("RELOCATE CALLED: [" +currentStatus+ "] -> ["+targetStatus+"]"); - if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) { - // DEVICE -> HOST - DataBuffer targetBuffer = point.getBuffer(); - if (targetBuffer == null) - throw new IllegalStateException("Target buffer is NULL!"); - - Pointer devicePointer = new CudaPointer(point.getPointers().getDevicePointer().address()); - - } else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) { - // HOST -> DEVICE - - - // TODO: this probably should be removed - if (point.isConstant()) { - //log.info("Skipping relocation for constant"); - return; - } - - if (point.getPointers().getDevicePointer() == null) { - throw new IllegalStateException("devicePointer is NULL!"); - } - - val profD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), - AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice, - context.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address() - + "] -> [" + point.getDevicePointer().address() + "]"); - - flowController.commitTransfer(context.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - - //context.syncOldStream(); - - } else - throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus - + "] -> [" + targetStatus + "]"); } /** @@ -440,11 +258,6 @@ public class CudaZeroHandler implements MemoryHandler { @Override @Deprecated public void copyforward(AllocationPoint point, AllocationShape shape) { - /* - Technically that's just a case for relocate, with source as HOST and target point.getAllocationStatus() - */ - log.info("copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape()); - //relocate(AllocationStatus.HOST, point.getAllocationStatus(), point, shape); throw new UnsupportedOperationException("Deprecated call"); } @@ -467,15 +280,7 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void free(AllocationPoint point, AllocationStatus target) { - //if (point.getAllocationStatus() == AllocationStatus.DEVICE) - //deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId()); - //zeroAllocations.get(point.getBucketId()).remove(point.getObjectId()); - if (point.getAllocationStatus() == AllocationStatus.DEVICE) - deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(), - AllocationUtils.getRequiredMemory(point.getShape())); - - memoryProvider.free(point); } /** @@ -525,7 +330,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext tContext = null; if (dstBuffer.isConstant()) { - org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L); + org.bytedeco.javacpp.Pointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L); org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -534,14 +339,34 @@ public class CudaZeroHandler implements MemoryHandler { point.tickHostRead(); } else { + // if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well + Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset); + + if (tContext == null) + tContext = flowController.prepareAction(point); + + var prof = PerformanceTracker.getInstance().helperStartTransaction(); + + flowController.commitTransfer(tContext.getSpecialStream()); + + if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0) + throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]"); + + flowController.commitTransfer(tContext.getSpecialStream()); + + PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); + + flowController.registerAction(tContext, point); + point.tickDeviceWrite(); + // we optionally copy to host memory - if (point.getPointers().getHostPointer() != null) { - Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); + if (point.getHostPointer() != null) { + Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset); CudaContext context = flowController.prepareAction(point); tContext = context; - val prof = PerformanceTracker.getInstance().helperStartTransaction(); + prof = PerformanceTracker.getInstance().helperStartTransaction(); if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0) throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]"); @@ -552,28 +377,10 @@ public class CudaZeroHandler implements MemoryHandler { if (point.getAllocationStatus() == AllocationStatus.HOST) flowController.registerAction(context, point); + + point.tickHostRead(); } } - - // if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset); - - if (tContext == null) - tContext = flowController.prepareAction(point); - - val prof = PerformanceTracker.getInstance().helperStartTransaction(); - - if (nativeOps.memcpyAsync(rDP, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, tContext.getSpecialStream()) == 0) - throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + rDP.address() + "]"); - - flowController.commitTransfer(tContext.getSpecialStream()); - - PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE); - - flowController.registerAction(tContext, point); - point.tickDeviceWrite(); - } } @Override @@ -581,7 +388,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext context) { AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); - Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset); + Pointer dP = new CudaPointer((point.getDevicePointer().address()) + dstOffset); if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) throw new ND4JIllegalStateException("memcpyAsync failed"); @@ -604,7 +411,7 @@ public class CudaZeroHandler implements MemoryHandler { CudaContext context = getCudaContext(); AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); - Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset); + Pointer dP = new CudaPointer((point.getHostPointer().address()) + dstOffset); val profH = PerformanceTracker.getInstance().helperStartTransaction(); @@ -614,7 +421,7 @@ public class CudaZeroHandler implements MemoryHandler { PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST); if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset); + Pointer rDP = new CudaPointer(point.getDevicePointer().address() + dstOffset); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -717,23 +524,22 @@ public class CudaZeroHandler implements MemoryHandler { @Override public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) { // TODO: It would be awesome to get rid of typecasting here - //getCudaContext().syncOldStream(); AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); // if that's device state, we probably might want to update device memory state if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) { if (!dstPoint.isActualOnDeviceSide()) { - // log.info("Relocating to GPU"); - relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context); + //relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context); + throw new UnsupportedOperationException("Pew-pew"); } } - // we update memory use counter, to announce that it's somehow used on device - dstPoint.tickDeviceRead(); + if (dstPoint.getDevicePointer() == null) + return null; - // return pointer with offset if needed. length is specified for constructor compatibility purposes - val p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(), - (buffer.offset() * buffer.getElementSize())); + + // return pointer. length is specified for constructor compatibility purposes. Offset is accounted at C++ side + val p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0); if (OpProfiler.getInstance().getConfig().isCheckLocality()) NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(context.getOldStream(), p, 1); @@ -749,10 +555,17 @@ public class CudaZeroHandler implements MemoryHandler { case SHORT: case UINT16: case HALF: + case BFLOAT16: return p.asShortPointer(); case UINT64: case LONG: return p.asLongPointer(); + case UTF8: + case UBYTE: + case BYTE: + return p.asBytePointer(); + case BOOL: + return p.asBooleanPointer(); default: return p; } @@ -769,17 +582,14 @@ public class CudaZeroHandler implements MemoryHandler { AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint(); // return pointer with offset if needed. length is specified for constructor compatibility purposes - if (dstPoint.getPointers().getHostPointer() == null) { + if (dstPoint.getHostPointer() == null) { return null; } - //dstPoint.tickHostWrite(); - //dstPoint.tickHostRead(); - //log.info("Requesting host pointer for {}", buffer); - //getCudaContext().syncOldStream(); + synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint); - CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(), - (buffer.offset() * buffer.getElementSize())); + CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0); + switch (buffer.dataType()) { case DOUBLE: return p.asDoublePointer(); @@ -805,6 +615,9 @@ public class CudaZeroHandler implements MemoryHandler { public synchronized void relocateObject(DataBuffer buffer) { AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + // we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT) if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE) return; @@ -838,14 +651,14 @@ public class CudaZeroHandler implements MemoryHandler { // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually // host part is optional if (dstPoint.getHostPointer() != null) { - val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false); - dstPoint.getPointers().setHostPointer(pairH.getHostPointer()); + //val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false); + //dstPoint.getPointers().setHostPointer(pairH.getHostPointer()); } - val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); - dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer()); + //val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); + //dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer()); - //log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address()); + ////log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address()); CudaContext context = getCudaContext(); @@ -876,10 +689,10 @@ public class CudaZeroHandler implements MemoryHandler { Nd4j.getMemoryManager().memcpy(nBuffer, buffer); - dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer()); + //dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer()); if (dstPoint.getHostPointer() != null) { - dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer()); + // dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer()); } dstPoint.setDeviceId(deviceId); @@ -908,11 +721,10 @@ public class CudaZeroHandler implements MemoryHandler { context.syncSpecialStream(); } - memoryProvider.free(dstPoint); - deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape())); + //deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape())); // we replace original device pointer with new one - alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); + //alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false); val profD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -940,6 +752,9 @@ public class CudaZeroHandler implements MemoryHandler { public boolean promoteObject(DataBuffer buffer) { AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + if (dstPoint.getAllocationStatus() != AllocationStatus.HOST) return false; @@ -952,20 +767,19 @@ public class CudaZeroHandler implements MemoryHandler { Nd4j.getConstantHandler().moveToConstantSpace(buffer); } else { - PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE); + PointersPair pair = null; //memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE); if (pair != null) { Integer deviceId = getDeviceId(); // log.info("Promoting object to device: [{}]", deviceId); - dstPoint.getPointers().setDevicePointer(pair.getDevicePointer()); + //dstPoint.setDevicePointer(pair.getDevicePointer()); dstPoint.setAllocationStatus(AllocationStatus.DEVICE); deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId()); zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId()); - deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, - AllocationUtils.getRequiredMemory(dstPoint.getShape())); + //deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, AllocationUtils.getRequiredMemory(dstPoint.getShape())); dstPoint.tickHostWrite(); @@ -1103,7 +917,7 @@ public class CudaZeroHandler implements MemoryHandler { if (deviceAllocations.get(deviceId).containsKey(objectId)) throw new IllegalStateException("Can't happen ever"); - deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape())); + //deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape())); point.setAllocationStatus(AllocationStatus.HOST); @@ -1119,6 +933,9 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) { + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); + forget(point, AllocationStatus.HOST); flowController.waitTillReleased(point); @@ -1127,8 +944,8 @@ public class CudaZeroHandler implements MemoryHandler { if (point.getHostPointer() != null) { free(point, AllocationStatus.HOST); - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1; - zeroUseCounter.addAndGet(reqMem); + //long reqMem = AllocationUtils.getRequiredMemory(point.getShape()) * -1; + //zeroUseCounter.addAndGet(reqMem); } point.setAllocationStatus(AllocationStatus.DEALLOCATED); @@ -1252,4 +1069,9 @@ public class CudaZeroHandler implements MemoryHandler { public FlowController getFlowController() { return flowController; } + + @Override + public MemoryProvider getMemoryProvider() { + return null; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index da36da6db..ad820c109 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -147,7 +147,7 @@ public class CudaMemoryManager extends BasicMemoryManager { // Nd4j.getShapeInfoProvider().purgeCache(); // purge memory cache - AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache(); + //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java deleted file mode 100644 index 1ba6bf34a..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaCachingZeroProvider.java +++ /dev/null @@ -1,303 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.conf.Configuration; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.jita.memory.MemoryProvider; -import org.slf4j.Logger; -import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import org.nd4j.jita.allocator.impl.MemoryTracker; - - -/** - * This is MemoryProvider implementation, that adds cache for memory reuse purposes. Only host memory is cached for future reuse. - * - * If some memory chunk gets released via allocator, it'll be probably saved for future reused within same JVM process. - * - * @author raver119@gmail.com - */ -public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider { - private static Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class); - - protected volatile ConcurrentHashMap zeroCache = new ConcurrentHashMap<>(); - - protected final AtomicLong cacheZeroHit = new AtomicLong(0); - protected final AtomicLong cacheZeroMiss = new AtomicLong(0); - - protected final AtomicLong cacheDeviceHit = new AtomicLong(0); - protected final AtomicLong cacheDeviceMiss = new AtomicLong(0); - - - - private final AtomicLong allocRequests = new AtomicLong(0); - - protected final AtomicLong zeroCachedAmount = new AtomicLong(0); - protected List deviceCachedAmount = new ArrayList<>(); - - - protected final Semaphore singleLock = new Semaphore(1); - - // we don't cache allocations greater then this value - //protected final long MAX_SINGLE_ALLOCATION = configuration.getMaximumHostCacheableLength(); - - // maximum cached size of memory - //protected final long MAX_CACHED_MEMORY = configuration.getMaximumHostCache(); - - // memory chunks below this threshold will be guaranteed regardless of number of cache entries - // that especially covers all possible variations of shapeInfoDataBuffers in all possible cases - protected final long FORCED_CACHE_THRESHOLD = 96; - - // number of preallocation entries for each yet-unknown shape - //protected final int PREALLOCATION_LIMIT = configuration.getPreallocationCalls(); - - public CudaCachingZeroProvider() { - - } - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape. - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - long reqMemory = AllocationUtils.getRequiredMemory(shape); - - if (location == AllocationStatus.HOST && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) { - - val cache = zeroCache.get(shape); - if (cache != null) { - val pointer = cache.poll(); - if (pointer != null) { - cacheZeroHit.incrementAndGet(); - - // since this memory chunk is going to be used now, remove it's amount from - zeroCachedAmount.addAndGet(-1 * reqMemory); - - val pair = new PointersPair(); - pair.setDevicePointer(new CudaPointer(pointer.address())); - pair.setHostPointer(new CudaPointer(pointer.address())); - - point.setAllocationStatus(AllocationStatus.HOST); - - MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMemory); - MemoryTracker.getInstance().decrementCachedHostAmount(reqMemory); - - return pair; - } - } - cacheZeroMiss.incrementAndGet(); - - if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation() && zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10 - && reqMemory < 16 * 1024 * 1024L) { - val preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls()); - preallocator.start(); - } - - cacheZeroMiss.incrementAndGet(); - return super.malloc(shape, point, location); - } - - return super.malloc(shape, point, location); - } - - - - protected void ensureCacheHolder(AllocationShape shape) { - if (!zeroCache.containsKey(shape)) { - try { - singleLock.acquire(); - if (!zeroCache.containsKey(shape)) { - zeroCache.put(shape, new CacheHolder(shape, zeroCachedAmount)); - } - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - singleLock.release(); - } - } - - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in. - * - * PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse. - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - super.free(point); - } else { - // if this point has no allocated chunk - step over it - if (point.getHostPointer() == null) - return; - - AllocationShape shape = point.getShape(); - long reqMemory = AllocationUtils.getRequiredMemory(shape); - - // we don't cache too big objects - if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength() || zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) { - super.free(point); - return; - } - - ensureCacheHolder(shape); - - /* - Now we should decide if this object can be cached or not - */ - CacheHolder cache = zeroCache.get(shape); - - // memory chunks < threshold will be cached no matter what - if (reqMemory <= FORCED_CACHE_THRESHOLD) { - Pointer.memset(point.getHostPointer(), 0, reqMemory); - cache.put(new CudaPointer(point.getHostPointer().address())); - } else { - long cacheEntries = cache.size(); - long cacheHeight = zeroCache.size(); - - // total memory allocated within this bucket - long cacheDepth = cacheEntries * reqMemory; - - Pointer.memset(point.getHostPointer(), 0, reqMemory); - cache.put(new CudaPointer(point.getHostPointer().address())); - } - - MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMemory); - MemoryTracker.getInstance().incrementCachedHostAmount(reqMemory); - } - } - - private float getZeroCacheHitRatio() { - long totalHits = cacheZeroHit.get() + cacheZeroMiss.get(); - float cacheRatio = cacheZeroHit.get() * 100 / (float) totalHits; - return cacheRatio; - } - - private float getDeviceCacheHitRatio() { - long totalHits = cacheDeviceHit.get() + cacheDeviceMiss.get(); - float cacheRatio = cacheDeviceHit.get() * 100 / (float) totalHits; - return cacheRatio; - } - - @Deprecated - public void printCacheStats() { - log.debug("Cached host amount: " + zeroCachedAmount.get()); - log.debug("Cached device amount: " + deviceCachedAmount.get(0).get()); - log.debug("Total shapes in cache: " + zeroCache.size()); - log.debug("Current host hit ratio: " + getZeroCacheHitRatio()); - log.debug("Current device hit ratio: " + getDeviceCacheHitRatio()); - } - - protected class CacheHolder { - private Queue queue = new ConcurrentLinkedQueue<>(); - private volatile int counter = 0; - private long reqMem = 0; - private final AtomicLong allocCounter; - - public CacheHolder(AllocationShape shape, AtomicLong counter) { - this.reqMem = AllocationUtils.getRequiredMemory(shape); - this.allocCounter = counter; - } - - public synchronized int size() { - return counter; - } - - public synchronized Pointer poll() { - val pointer = queue.poll(); - if (pointer != null) - counter--; - - return pointer; - } - - public synchronized void put(Pointer pointer) { - allocCounter.addAndGet(reqMem); - counter++; - queue.add(pointer); - } - } - - protected class CachePreallocator extends Thread implements Runnable { - - private AllocationShape shape; - private AllocationStatus location; - private int target; - - public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) { - this.shape = shape; - this.target = numberOfEntries; - this.location = location; - } - - @Override - public void run() { - ensureCacheHolder(shape); - - for (int i = 0; i < target; i++) { - val point = new AllocationPoint(); - - val pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location); - if (this.location == AllocationStatus.HOST) { - Pointer pointer = new CudaPointer(pair.getHostPointer().address()); - CudaCachingZeroProvider.this.zeroCache.get(shape).put(pointer); - } - } - } - } - - @Override - public void purgeCache() { - for (AllocationShape shape : zeroCache.keySet()) { - Pointer ptr = null; - while ((ptr = zeroCache.get(shape).poll()) != null) { - freeHost(ptr); - MemoryTracker.getInstance().decrementCachedHostAmount(shape.getNumberOfBytes()); - } - } - - zeroCachedAmount.set(0); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java deleted file mode 100644 index eba4d74d0..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaDirectProvider.java +++ /dev/null @@ -1,239 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import lombok.var; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.memory.MemoryProvider; -import org.nd4j.linalg.api.memory.AllocationsTracker; -import org.nd4j.linalg.api.memory.enums.AllocationKind; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.nd4j.jita.allocator.impl.MemoryTracker; - -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -/** - * @author raver119@gmail.com - */ -public class CudaDirectProvider implements MemoryProvider { - - protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L; - private static Logger log = LoggerFactory.getLogger(CudaDirectProvider.class); - protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - protected volatile ConcurrentHashMap validator = new ConcurrentHashMap<>(); - - - private AtomicLong emergencyCounter = new AtomicLong(0); - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - - //log.info("shape onCreate: {}, target: {}", shape, location); - - switch (location) { - case HOST: { - long reqMem = AllocationUtils.getRequiredMemory(shape); - - // FIXME: this is WRONG, and directly leads to memleak - if (reqMem < 1) - reqMem = 1; - - val pointer = nativeOps.mallocHost(reqMem, 0); - if (pointer == null) - throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: " - + Thread.currentThread().getId()); - - // log.info("Host allocation, Thread id: {}, ReqMem: {}, Pointer: {}", Thread.currentThread().getId(), reqMem, pointer != null ? pointer.address() : null); - - val hostPointer = new CudaPointer(pointer); - - val devicePointerInfo = new PointersPair(); - if (point.getPointers().getDevicePointer() == null) { - point.setAllocationStatus(AllocationStatus.HOST); - devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem)); - } else - devicePointerInfo.setDevicePointer(point.getDevicePointer()); - - devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem)); - - point.setPointers(devicePointerInfo); - - MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem); - - return devicePointerInfo; - } - case DEVICE: { - // cudaMalloc call - val deviceId = AtomicAllocator.getInstance().getDeviceId(); - long reqMem = AllocationUtils.getRequiredMemory(shape); - - // FIXME: this is WRONG, and directly leads to memleak - if (reqMem < 1) - reqMem = 1; - - AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem); - var pointer = nativeOps.mallocDevice(reqMem, deviceId, 0); - if (pointer == null) { - // try to purge stuff if we're low on memory - purgeCache(deviceId); - - // call for gc - Nd4j.getMemoryManager().invokeGc(); - - pointer = nativeOps.mallocDevice(reqMem, deviceId, 0); - if (pointer == null) - return null; - } - - val devicePointer = new CudaPointer(pointer); - - var devicePointerInfo = point.getPointers(); - if (devicePointerInfo == null) - devicePointerInfo = new PointersPair(); - devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem)); - - point.setAllocationStatus(AllocationStatus.DEVICE); - point.setDeviceId(deviceId); - MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem); - return devicePointerInfo; - } - default: - throw new IllegalStateException("Unsupported location for malloc: [" + location + "]"); - } - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - switch (point.getAllocationStatus()) { - case HOST: { - // cudaFreeHost call here - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()); - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - long result = nativeOps.freeHost(point.getPointers().getHostPointer()); - if (result == 0) { - throw new RuntimeException("Can't deallocate [HOST] memory..."); - } - - MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem); - } - break; - case DEVICE: { - if (point.isConstant()) - return; - - long reqMem = AllocationUtils.getRequiredMemory(point.getShape()); - - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, point.getDeviceId(), reqMem); - - val pointers = point.getPointers(); - - long result = nativeOps.freeDevice(pointers.getDevicePointer(), 0); - if (result == 0) - throw new RuntimeException("Can't deallocate [DEVICE] memory..."); - - MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem); - } - break; - default: - throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]"); - } - } - - /** - * This method checks specified device for specified amount of memory - * - * @param deviceId - * @param requiredMemory - * @return - */ - public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) { - /* - long[] totalMem = new long[1]; - long[] freeMem = new long[1]; - - - JCuda.cudaMemGetInfo(freeMem, totalMem); - - long free = freeMem[0]; - long total = totalMem[0]; - long used = total - free; - - /* - We don't want to allocate memory if it's too close to the end of available ram. - */ - //if (configuration != null && used > total * configuration.getMaxDeviceMemoryUsed()) return false; - - /* - if (free + requiredMemory < total * 0.85) - return true; - else return false; - */ - long freeMem = nativeOps.getDeviceFreeMemory(-1); - if (freeMem - requiredMemory < DEVICE_RESERVED_SPACE) - return false; - else - return true; - } - - protected void freeHost(Pointer pointer) { - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - nativeOps.freeHost(pointer); - } - - protected void freeDevice(Pointer pointer, int deviceId) { - val nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - nativeOps.freeDevice(pointer, 0); - } - - protected void purgeCache(int deviceId) { - // - } - - @Override - public void purgeCache() { - // no-op - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java deleted file mode 100644 index 2157dfb56..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/impl/CudaFullCachingProvider.java +++ /dev/null @@ -1,220 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.memory.impl; - -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.impl.MemoryTracker; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.PointersPair; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -/** - * This MemoryProvider implementation does caching for both host and device memory within predefined limits. - * - * @author raver119@gmail.com - */ -public class CudaFullCachingProvider extends CudaCachingZeroProvider { - - //protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation(); - - //protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache(); - - - protected volatile ConcurrentHashMap> deviceCache = - new ConcurrentHashMap<>(); - - - private static Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class); - - public CudaFullCachingProvider() { - - init(); - } - - public void init() { - int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); - - deviceCachedAmount = new ArrayList<>(); - - for (int i = 0; i < numDevices; i++) { - deviceCachedAmount.add(new AtomicLong(0)); - } - } - - /** - * This method provides PointersPair to memory chunk specified by AllocationShape - * - * PLEASE NOTE: This method can actually ignore malloc request, and give out previously cached free memory chunk with equal shape. - * - * @param shape shape of desired memory chunk - * @param point target AllocationPoint structure - * @param location either HOST or DEVICE - * @return - */ - @Override - public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) { - val reqMemory = AllocationUtils.getRequiredMemory(shape); - if (location == AllocationStatus.DEVICE && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) { - - - val deviceId = AtomicAllocator.getInstance().getDeviceId(); - ensureDeviceCacheHolder(deviceId, shape); - - val cache = deviceCache.get(deviceId).get(shape); - if (cache != null) { - val pointer = cache.poll(); - if (pointer != null) { - cacheDeviceHit.incrementAndGet(); - - deviceCachedAmount.get(deviceId).addAndGet(-reqMemory); - - val pair = new PointersPair(); - pair.setDevicePointer(pointer); - - point.setAllocationStatus(AllocationStatus.DEVICE); - point.setDeviceId(deviceId); - - - MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory); - - return pair; - } - } - cacheDeviceMiss.incrementAndGet(); - return super.malloc(shape, point, location); - } - return super.malloc(shape, point, location); - } - - /** - * This method frees specific chunk of memory, described by AllocationPoint passed in - * - * PLEASE NOTE: This method can actually ignore free, and keep released memory chunk for future reuse. - * - * @param point - */ - @Override - public void free(AllocationPoint point) { - if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - if (point.isConstant()) - return; - - val shape = point.getShape(); - val deviceId = point.getDeviceId(); - val address = point.getDevicePointer().address(); - val reqMemory = AllocationUtils.getRequiredMemory(shape); - // we don't cache too big objects - - if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength() || deviceCachedAmount.get(deviceId).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) { - super.free(point); - return; - } - - ensureDeviceCacheHolder(deviceId, shape); - - val cache = deviceCache.get(deviceId).get(shape); - - if (point.getDeviceId() != deviceId) - throw new RuntimeException("deviceId changed!"); - - // memory chunks < threshold will be cached no matter what - if (reqMemory <= FORCED_CACHE_THRESHOLD) { - cache.put(new CudaPointer(point.getDevicePointer().address())); - MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory); - return; - } else { - - cache.put(new CudaPointer(point.getDevicePointer().address())); - - MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory); - MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory); - return; - } - } - super.free(point); - } - - /** - * This method checks, if storage contains holder for specified shape - * - * @param deviceId - * @param shape - */ - protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) { - if (!deviceCache.containsKey(deviceId)) { - try { - synchronized (this) { - if (!deviceCache.containsKey(deviceId)) { - deviceCache.put(deviceId, new ConcurrentHashMap()); - } - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - if (!deviceCache.get(deviceId).containsKey(shape)) { - try { - singleLock.acquire(); - - if (!deviceCache.get(deviceId).containsKey(shape)) { - deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount.get(deviceId))); - } - } catch (Exception e) { - - } finally { - singleLock.release(); - } - } - } - - @Override - protected synchronized void purgeCache(int deviceId) { - for (AllocationShape shape : deviceCache.get(deviceId).keySet()) { - Pointer ptr = null; - while ((ptr = deviceCache.get(deviceId).get(shape).poll()) != null) { - freeDevice(ptr, deviceId); - MemoryTracker.getInstance().decrementCachedAmount(deviceId, shape.getNumberOfBytes()); - } - } - - deviceCachedAmount.get(deviceId).set(0); - } - - @Override - public synchronized void purgeCache() { - for (Integer device : deviceCache.keySet()) { - purgeCache(device); - } - super.purgeCache(); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java index 79d87a01e..df44adb17 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java @@ -17,34 +17,39 @@ package org.nd4j.linalg.jcublas; +import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.bytedeco.javacpp.BytePointer; +import org.nd4j.base.Preconditions; +import org.nd4j.graph.FlatArray; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.FloatBuffer; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.JvmShapeInfo; -import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.workspace.WorkspaceUtils; import org.nd4j.nativeblas.NativeOpsHolder; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.List; -import java.util.concurrent.atomic.AtomicLong; /** * @@ -387,10 +392,6 @@ public class JCublasNDArray extends BaseNDArray { super(data, order); } - public JCublasNDArray(FloatBuffer floatBuffer, char order) { - super(floatBuffer, order); - } - public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) { super(buffer, shape, strides); } @@ -574,26 +575,16 @@ public class JCublasNDArray extends BaseNDArray { MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST; val prof = PerformanceTracker.getInstance().helperStartTransaction(); - if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { - // d2d copy + if (srcPoint.isActualOnDeviceSide()) { route = 1; NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); dstPoint.tickDeviceWrite(); direction = MemcpyDirection.DEVICE_TO_DEVICE; - } else if (dstPoint.getAllocationStatus() == AllocationStatus.HOST && srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) { - route = 2; - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyDeviceToHost, blocking ? context.getOldStream() : context.getSpecialStream()); - dstPoint.tickHostWrite(); - direction = MemcpyDirection.DEVICE_TO_HOST; - } else if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && srcPoint.getAllocationStatus() == AllocationStatus.HOST) { + } else { route = 3; NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, blocking ? context.getOldStream() : context.getSpecialStream()); dstPoint.tickDeviceWrite(); direction = MemcpyDirection.HOST_TO_DEVICE; - } else { - route = 4; - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), this.data.length() * this.data.getElementSize(), CudaConstants.cudaMemcpyHostToHost, blocking ? context.getOldStream() : context.getSpecialStream()); - dstPoint.tickHostWrite(); } @@ -650,30 +641,16 @@ public class JCublasNDArray extends BaseNDArray { Nd4j.getMemoryManager().setCurrentWorkspace(target); -// log.info("Leveraging..."); - INDArray copy = null; if (!this.isView()) { - //if (1 < 0) { Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.length(), false); + val buffer = Nd4j.createBuffer(this.length(), false); - AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); - AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); + val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); + val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); -/* - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointDst.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0) - throw new ND4JIllegalStateException("memsetAsync 1 failed"); - - context.syncOldStream(); - - if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointSrc.getDevicePointer(), 0, 1, 0, context.getOldStream()) == 0) - throw new ND4JIllegalStateException("memsetAsync 2 failed"); - - context.syncOldStream(); -*/ + val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; val perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -690,12 +667,11 @@ public class JCublasNDArray extends BaseNDArray { context.syncOldStream(); - PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); + PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction); copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); // tag buffer as valid on device side - pointDst.tickHostRead(); pointDst.tickDeviceWrite(); AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc); @@ -728,6 +704,7 @@ public class JCublasNDArray extends BaseNDArray { val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer); val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data); + val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; @@ -764,6 +741,38 @@ public class JCublasNDArray extends BaseNDArray { return copy; } + protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) { + Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only"); + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(bos); + + val numWords = this.length(); + val ub = (CudaUtf8Buffer) buffer; + // writing length first + val t = length(); + val ptr = (BytePointer) ub.pointer(); + + // now write all strings as bytes + for (int i = 0; i < ub.length(); i++) { + dos.writeByte(ptr.get(i)); + } + + val bytes = bos.toByteArray(); + return FlatArray.createBufferVector(builder, bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getString(long index) { + if (!isS()) + throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]"); + + return ((CudaUtf8Buffer) data).getString(index); + } + /* @Override public INDArray convertToHalfs() { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 0bcb6e562..c529c4f7c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -18,11 +18,9 @@ package org.nd4j.linalg.jcublas; import lombok.extern.slf4j.Slf4j; import lombok.val; -import lombok.var; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.ops.custom.Flatten; import org.nd4j.linalg.api.ops.impl.shape.Concat; @@ -34,12 +32,10 @@ import org.nd4j.linalg.jcublas.buffer.*; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.primitives.Pair; import org.bytedeco.javacpp.*; -import org.bytedeco.javacpp.indexer.*; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.utils.AllocationUtils; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -51,19 +47,12 @@ import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.BaseNDArrayFactory; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.blas.*; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.nativeblas.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.io.File; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.charset.Charset; import java.util.*; /** @@ -216,7 +205,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray create(Collection strings, long[] shape, char order) { val pairShape = Nd4j.getShapeInfoProvider().createShapeInformation(shape, order, DataType.UTF8); - val buffer = new Utf8Buffer(strings); + val buffer = new CudaUtf8Buffer(strings); val list = new ArrayList(strings); return Nd4j.createArrayFromShapeBuffer(buffer, pairShape); } @@ -360,8 +349,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray concat(int dimension, INDArray... toConcat) { - if (Nd4j.getExecutioner() instanceof GridExecutioner) - ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + Nd4j.getExecutioner().push(); return Nd4j.exec(new Concat(dimension, toConcat))[0]; } @@ -517,9 +505,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { AtomicAllocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(ret, source); - Pointer x = AtomicAllocator.getInstance().getPointer(source, context); + val x = ((BaseCudaDataBuffer) source.data()).getOpaqueDataBuffer(); + val z = ((BaseCudaDataBuffer) ret.data()).getOpaqueDataBuffer(); Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(ret, context); Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context); PointerPointer extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), @@ -545,14 +533,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { nativeOps.pullRows(extras, - null, - (LongPointer) source.shapeInfoDataBuffer().addressPointer(), - x, - (LongPointer) xShape, - null, - (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - z, - (LongPointer) zShape, + x, (LongPointer) source.shapeInfoDataBuffer().addressPointer(), (LongPointer) xShape, + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) zShape, indexes.length, (LongPointer) pIndex, (LongPointer) tadShapeInfo, @@ -601,7 +583,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); - xPointers[i] = point.getPointers().getDevicePointer().address(); + xPointers[i] = point.getDevicePointer().address(); point.tickDeviceWrite(); } @@ -710,7 +692,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); AllocationPoint point = allocator.getAllocationPoint(arrays[i]); - xPointers[i] = point.getPointers().getDevicePointer().address(); + xPointers[i] = point.getDevicePointer().address(); point.tickDeviceWrite(); } @@ -1324,11 +1306,11 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { PointerPointer extraz = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); + val x = ((BaseCudaDataBuffer) tensor.data()).getOpaqueDataBuffer(); + + nativeOps.tear(extraz, - null, - (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(tensor, context), - (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context), + x, (LongPointer) tensor.shapeInfoDataBuffer().addressPointer(), (LongPointer) AtomicAllocator.getInstance().getPointer(tensor.shapeInfoDataBuffer(), context), new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)), (LongPointer) AtomicAllocator.getInstance().getPointer(result[0].shapeInfoDataBuffer(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context), diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 9564fb15e..6c82cf1de 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -21,6 +21,7 @@ import lombok.NonNull; import lombok.val; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; +import org.nd4j.base.Preconditions; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; @@ -38,6 +39,8 @@ import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.enums.MemoryKind; +import org.nd4j.linalg.api.memory.enums.MirroringPolicy; +import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -47,7 +50,9 @@ import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.memory.abstracts.DummyWorkspace; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.LongUtils; +import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -74,6 +79,7 @@ import java.util.Collection; * @author raver119@gmail.com */ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer, Deallocatable { + protected OpaqueDataBuffer ptrDataBuffer; @Getter protected transient volatile AllocationPoint allocationPoint; @@ -88,10 +94,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda } + public OpaqueDataBuffer getOpaqueDataBuffer() { + return ptrDataBuffer; + } + + public BaseCudaDataBuffer(@NonNull Pointer pointer, @NonNull Pointer specialPointer, @NonNull Indexer indexer, long length) { - this.allocationPoint = AtomicAllocator.getInstance().pickExternalBuffer(this); - this.allocationPoint.setPointers(new PointersPair(specialPointer, pointer)); - this.trackingPoint = allocationPoint.getObjectId(); this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.indexer = indexer; @@ -102,6 +110,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.length = length; initTypeAndSize(); + + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, this.type, false); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length); + this.allocationPoint.setPointers(pointer, specialPointer, length); + + Nd4j.getDeallocatorService().pickObject(this); } /** @@ -114,10 +128,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) { super(pointer, indexer, length); - //cuda specific bits - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); - this.trackingPoint = allocationPoint.getObjectId(); + // allocating interop buffer + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); + //cuda specific bits + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize); Nd4j.getDeallocatorService().pickObject(this); // now we're @@ -222,71 +237,153 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda } public void lazyAllocateHostPointer() { - if (allocationPoint.getPointers().getHostPointer() == null) + if (length() == 0) + return; + + // java side might be unaware of native-side buffer allocation + if (this.indexer == null || this.pointer == null || this.pointer.address() == 0) { initHostPointerAndIndexer(); + } else if (allocationPoint.getHostPointer() != null && allocationPoint.getHostPointer().address() != this.pointer.address()) { + initHostPointerAndIndexer(); + } + } + + protected BaseCudaDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) { + this(length, Nd4j.sizeOfDataType(dtype)); + + Pointer temp = null; + + switch (dataType()){ + case DOUBLE: + temp = new DoublePointer(buffer.asDoubleBuffer()); + break; + case FLOAT: + temp = new FloatPointer(buffer.asFloatBuffer()); + break; + case HALF: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case LONG: + temp = new LongPointer(buffer.asLongBuffer()); + break; + case INT: + temp = new IntPointer(buffer.asIntBuffer()); + break; + case SHORT: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case UBYTE: //Fall through + case BYTE: + temp = new BytePointer(buffer); + break; + case BOOL: + temp = new BooleanPointer(length()); + break; + case UTF8: + temp = new BytePointer(length()); + break; + case BFLOAT16: + temp = new ShortPointer(length()); + break; + case UINT16: + temp = new ShortPointer(length()); + break; + case UINT32: + temp = new IntPointer(length()); + break; + case UINT64: + temp = new LongPointer(length()); + break; + } + + // copy data to device + val stream = AtomicAllocator.getInstance().getDeviceContext().getSpecialStream(); + val ptr = ptrDataBuffer.specialBuffer(); + + if (offset > 0) + temp = new PagedPointer(temp.address() + offset * getElementSize()); + + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(ptr, temp, length * Nd4j.sizeOfDataType(dtype), CudaConstants.cudaMemcpyHostToDevice, stream); + stream.synchronize(); + + // mark device buffer as updated + allocationPoint.tickDeviceWrite(); } protected void initHostPointerAndIndexer() { - if (allocationPoint.getPointers().getHostPointer() == null) { + if (length() == 0) + return; + + if (allocationPoint.getHostPointer() == null) { val location = allocationPoint.getAllocationStatus(); if (parentWorkspace == null) { - val ptr = AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.HOST, this.allocationPoint, this.allocationPoint.getShape(), false); - this.allocationPoint.getPointers().setHostPointer(ptr.getHostPointer()); + //log.info("dbAllocate step"); + // let cpp allocate primary buffer + NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer); } else { + //log.info("ws alloc step"); val ptr = parentWorkspace.alloc(this.length * this.elementSize, MemoryKind.HOST, this.dataType(), false); - this.allocationPoint.getPointers().setHostPointer(ptr); + ptrDataBuffer.setPrimaryBuffer(ptr, this.length); } this.allocationPoint.setAllocationStatus(location); this.allocationPoint.tickDeviceWrite(); } + val hostPointer = allocationPoint.getHostPointer(); + + assert hostPointer != null; + switch (dataType()) { case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); break; case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); break; case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = Bfloat16Indexer.create((ShortPointer) pointer); break; case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; case UINT64: case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); break; case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = UShortIndexer.create((ShortPointer) pointer); break; case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); break; case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); break; case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); break; case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBooleanPointer(); + this.pointer = new CudaPointer(hostPointer, length, 0).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); break; + case UTF8: + this.pointer = new CudaPointer(hostPointer, length, 0).asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; default: throw new UnsupportedOperationException(); } @@ -294,21 +391,25 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda protected void initPointers(long length, int elementSize, boolean initialize) { this.allocationMode = AllocationMode.MIXED_DATA_TYPES; - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), initialize); this.length = length; - //allocationPoint.attachBuffer(this); this.elementSize = (byte) elementSize; - this.trackingPoint = allocationPoint.getObjectId(); + this.offset = 0; this.originalOffset = 0; + // we allocate native DataBuffer AND it will contain our device pointer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * type.width()); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + val devicePtr = allocationPoint.getDevicePointer(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + + // let deallocator pick up this object Nd4j.getDeallocatorService().pickObject(this); - - // if only host - if (allocationPoint.getPointers().getHostPointer() == null) - return; - - initHostPointerAndIndexer(); } public BaseCudaDataBuffer(long length, int elementSize, boolean initialize) { @@ -323,72 +424,45 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.attached = true; this.parentWorkspace = workspace; - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, this.elementSize, dataType()), initialize); this.length = length; - this.trackingPoint = allocationPoint.getObjectId(); this.offset = 0; this.originalOffset = 0; + // allocating empty databuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); + + if (workspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) { + val devicePtr = workspace.alloc(length * elementSize, MemoryKind.DEVICE, type, initialize); + + // allocate from workspace, and pass it to native DataBuffer + ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + } else { + // we can register this pointer as device, because it's pinned memory + val devicePtr = workspace.alloc(length * elementSize, MemoryKind.HOST, type, initialize); + ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); + + if (initialize) { + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + ctx.getSpecialStream().synchronize(); + } + } + + this.allocationPoint = new AllocationPoint(ptrDataBuffer, elementSize * length); + + // registering for deallocation Nd4j.getDeallocatorService().pickObject(this); workspaceGenerationId = workspace.getGenerationId(); this.attached = true; this.parentWorkspace = workspace; - - if (allocationPoint.getHostPointer() == null) - return; - - switch (dataType()) { - case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case UINT32: - case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = Bfloat16Indexer.create((ShortPointer) pointer); - break; - case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = HalfIndexer.create((ShortPointer) pointer); - break; - case UINT64: - case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBooleanPointer(); - indexer = BooleanIndexer.create((BooleanPointer) pointer); - break; - case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = UShortIndexer.create((ShortPointer) pointer); - break; - case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = ShortIndexer.create((ShortPointer) pointer); - break; - case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); - indexer = ByteIndexer.create((BytePointer) pointer); - break; - case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); - indexer = UByteIndexer.create((BytePointer) pointer); - break; - default: - throw new UnsupportedOperationException("Unknown data type: " + dataType()); - } } @Override @@ -427,60 +501,71 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.length = length; this.offset = offset; this.originalOffset = offset; - this.trackingPoint = underlyingBuffer.getTrackingPoint(); this.elementSize = (byte) underlyingBuffer.getElementSize(); - this.allocationPoint = ((BaseCudaDataBuffer) underlyingBuffer).allocationPoint; // in case of view creation, we initialize underlying buffer regardless of anything - ((BaseCudaDataBuffer) underlyingBuffer).lazyAllocateHostPointer();; + ((BaseCudaDataBuffer) underlyingBuffer).lazyAllocateHostPointer(); + + // we're creating view of the native DataBuffer + ptrDataBuffer = ((BaseCudaDataBuffer) underlyingBuffer).ptrDataBuffer.createView(length * underlyingBuffer.getElementSize(), offset * underlyingBuffer.getElementSize()); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, length); + val hostPointer = allocationPoint.getHostPointer(); + + Nd4j.getDeallocatorService().pickObject(this); switch (underlyingBuffer.dataType()) { case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asDoublePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); break; case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asFloatPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); break; case UINT32: case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asIntPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); break; case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = Bfloat16Indexer.create((ShortPointer) pointer); break; case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); break; case UINT64: case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asLongPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); break; case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = UShortIndexer.create((ShortPointer) pointer); break; case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); break; case BOOL: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBooleanPointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); break; case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); break; case UBYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asBytePointer(); + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); break; + case UTF8: + Preconditions.checkArgument(offset == 0, "String array can't be a view"); + + this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; default: throw new UnsupportedOperationException(); } @@ -522,23 +607,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda set(data, data.length, 0, 0); } - public BaseCudaDataBuffer(byte[] data, long length, DataType type) { - this(ByteBuffer.wrap(data), length, type); - } - - public BaseCudaDataBuffer(ByteBuffer buffer, long length, DataType type) { - //super(buffer,length); - this(buffer, length, 0, type); - } - - public BaseCudaDataBuffer(ByteBuffer buffer, long length, long offset, DataType type) { - //super(buffer, length, offset); - this(length, Nd4j.sizeOfDataType(type), offset); - - Pointer srcPtr = new CudaPointer(new Pointer(buffer.order(ByteOrder.nativeOrder()))); - - allocator.memcpyAsync(this, srcPtr, length * elementSize, offset * elementSize); - } /** * This method always returns host pointer @@ -550,12 +618,12 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda if (released) throw new IllegalStateException("You can't use DataBuffer once it was released"); - return allocationPoint.getPointers().getHostPointer().address(); + return allocationPoint.getHostPointer().address(); } @Override public long platformAddress() { - return allocationPoint.getPointers().getDevicePointer().address(); + return allocationPoint.getDevicePointer().address(); } @Override @@ -585,7 +653,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -595,7 +663,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -611,7 +679,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -621,7 +689,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -631,7 +699,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(LongUtils.toLongs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -641,7 +709,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -651,7 +719,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -661,7 +729,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(ArrayUtil.toDouble(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -680,7 +748,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -690,7 +758,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -709,7 +777,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -721,7 +789,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -733,7 +801,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda data = ArrayUtil.cutBelowZero(data); case LONG: { val pointer = new LongPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -743,7 +811,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BFLOAT16: { val pointer = new ShortPointer(ArrayUtil.toBfloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -753,7 +821,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -763,7 +831,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -773,7 +841,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(ArrayUtil.toDouble(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); // we're keeping pointer reference for JVM @@ -799,7 +867,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -809,7 +877,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -825,7 +893,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -835,7 +903,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -845,7 +913,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(ArrayUtil.toLongArray(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -855,7 +923,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -865,7 +933,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -875,7 +943,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { DoublePointer pointer = new DoublePointer(ArrayUtil.toDoubles(data)); - Pointer srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + Pointer srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -901,7 +969,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda switch (dataType()) { case BOOL: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -911,7 +979,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case BYTE: { val pointer = new BytePointer(ArrayUtil.toBytes(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -927,7 +995,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case SHORT: { val pointer = new ShortPointer(ArrayUtil.toShorts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -937,7 +1005,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case INT: { val pointer = new IntPointer(ArrayUtil.toInts(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -947,7 +1015,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case LONG: { val pointer = new LongPointer(ArrayUtil.toLongs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -957,7 +1025,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case HALF: { val pointer = new ShortPointer(ArrayUtil.toHalfs(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -967,7 +1035,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case FLOAT: { val pointer = new FloatPointer(ArrayUtil.toFloats(data)); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -977,7 +1045,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda break; case DOUBLE: { val pointer = new DoublePointer(data); - val srcPtr = new CudaPointer(pointer.address() + (dstOffset * elementSize)); + val srcPtr = new CudaPointer(pointer.address() + (srcOffset * elementSize)); allocator.memcpyAsync(this, srcPtr, length * elementSize, dstOffset * elementSize); @@ -1252,7 +1320,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public boolean sameUnderlyingData(DataBuffer buffer) { - return buffer.getTrackingPoint() == getTrackingPoint(); + return ptrDataBuffer.address() == ((BaseCudaDataBuffer) buffer).ptrDataBuffer.address(); } /** @@ -1345,54 +1413,54 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda this.elementSize = (byte) Nd4j.sizeOfDataType(t); this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, t), false); - this.trackingPoint = allocationPoint.getObjectId(); + this.type = t; Nd4j.getDeallocatorService().pickObject(this); switch (type) { case DOUBLE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asDoublePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asDoublePointer(); indexer = DoubleIndexer.create((DoublePointer) pointer); } break; case FLOAT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asFloatPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asFloatPointer(); indexer = FloatIndexer.create((FloatPointer) pointer); } break; case HALF: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer(); indexer = HalfIndexer.create((ShortPointer) pointer); } break; case LONG: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asLongPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asLongPointer(); indexer = LongIndexer.create((LongPointer) pointer); } break; case INT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asIntPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); } break; case SHORT: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asShortPointer(); indexer = ShortIndexer.create((ShortPointer) pointer); } break; case UBYTE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBytePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer(); indexer = UByteIndexer.create((BytePointer) pointer); } break; case BYTE: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBytePointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBytePointer(); indexer = ByteIndexer.create((BytePointer) pointer); } break; case BOOL: { - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asBooleanPointer(); + this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length).asBooleanPointer(); indexer = BooleanIndexer.create((BooleanPointer) pointer); } break; @@ -1514,53 +1582,181 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda return super.getInt(ix); } + public void actualizePointerAndIndexer() { + val cptr = ptrDataBuffer.primaryBuffer(); + + // skip update if pointers are equal + if (cptr != null && pointer != null && cptr.address() == pointer.address()) + return; + + val t = dataType(); + if (t == DataType.BOOL) { + pointer = new PagedPointer(cptr, length).asBoolPointer(); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (t == DataType.UBYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.BYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.UINT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.SHORT) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.UINT32) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.INT) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.UINT64) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.LONG) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.BFLOAT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (t == DataType.HALF) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.FLOAT) { + pointer = new PagedPointer(cptr, length).asFloatPointer(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (t == DataType.DOUBLE) { + pointer = new PagedPointer(cptr, length).asDoublePointer(); + setIndexer(DoubleIndexer.create((DoublePointer) pointer)); + } else if (t == DataType.UTF8) { + pointer = new PagedPointer(cptr, length()).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else + throw new IllegalArgumentException("Unknown datatype: " + dataType()); + } + @Override public DataBuffer reallocate(long length) { + val oldHostPointer = this.ptrDataBuffer.primaryBuffer(); + val oldDevicePointer = this.ptrDataBuffer.specialBuffer(); - // we want to be sure this array isn't used anywhere RIGHT AT THIS MOMENT - Nd4j.getExecutioner().commit(); + if (isAttached()) { + val capacity = length * getElementSize(); + + if (oldDevicePointer != null && oldDevicePointer.address() != 0) { + val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.DEVICE, dataType(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(nPtr, oldDevicePointer, length * getElementSize(), 3, null); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + allocationPoint.tickDeviceRead(); + } + + if (oldHostPointer != null && oldHostPointer.address() != 0) { + val nPtr = getParentWorkspace().alloc(capacity, MemoryKind.HOST, dataType(), false); + Pointer.memcpy(nPtr, oldHostPointer, this.length() * getElementSize()); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + allocationPoint.tickHostRead(); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } - AllocationPoint old = allocationPoint; - allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); + workspaceGenerationId = getParentWorkspace().getGenerationId(); + } else { + this.ptrDataBuffer.expand(length); + val nPtr = new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length); - Nd4j.getDeallocatorService().pickObject(this); - trackingPoint = allocationPoint.getObjectId(); - val oldLength = this.length; + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } + + this.underlyingLength = length; this.length = length; - - // if original buffer had host pointer allocated, we'll reallocate host buffer as well - if (old.getHostPointer() != null) { - lazyAllocateHostPointer(); - } - - val context = AtomicAllocator.getInstance().getDeviceContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); - - MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; - val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - if (old.isActualOnDeviceSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream()); - } else if (old.isActualOnHostSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); - direction = MemcpyDirection.HOST_TO_DEVICE; - } - - context.getSpecialStream().synchronize(); - - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); - - allocationPoint.tickDeviceWrite(); - - // we need to update length with new value now - //this.length = length; - if(isAttached()){ - // do nothing here, that's workspaces - } else{ - AtomicAllocator.getInstance().freeMemory(old); - } - return this; } @@ -1575,7 +1771,8 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override protected void release() { if (!released) { - AtomicAllocator.getInstance().freeMemory(allocationPoint); + //AtomicAllocator.getInstance().freeMemory(allocationPoint);n + NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(allocationPoint.getPtrDataBuffer()); allocationPoint.setReleased(true); } released = true; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java index 193a9e21c..145816a5e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBfloat16DataBuffer.java @@ -46,6 +46,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaBfloat16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaBfloat16DataBuffer(byte[] data, long length) { - super(data, length, DataType.BFLOAT16); - } - - public CudaBfloat16DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.BFLOAT16); - } - - public CudaBfloat16DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.BFLOAT16); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java index a1b498785..08dbd9f39 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaBoolDataBuffer.java @@ -50,6 +50,10 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaBoolDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -132,18 +136,6 @@ public class CudaBoolDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaBoolDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaBoolDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaBoolDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaBoolDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java index 80fb7f804..d35b3c215 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaByteDataBuffer.java @@ -49,6 +49,10 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaByteDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaByteDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaByteDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaByteDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaByteDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java index 8ccc3cf81..789b213f1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaDoubleDataBuffer.java @@ -49,6 +49,10 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaDoubleDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -138,18 +142,6 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaDoubleDataBuffer(byte[] data, long length) { - super(data, length, DataType.DOUBLE); - } - - public CudaDoubleDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.DOUBLE); - } - - public CudaDoubleDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.DOUBLE); - } - @Override protected DataBuffer create(long length) { return new CudaDoubleDataBuffer(length); @@ -210,14 +202,7 @@ public class CudaDoubleDataBuffer extends BaseCudaDataBuffer { this.length = n; this.elementSize = 8; - //wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize()); - //wrappedBuffer.order(ByteOrder.nativeOrder()); - - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, - new AllocationShape(length, elementSize, DataType.DOUBLE), false); - this.trackingPoint = allocationPoint.getObjectId(); - //this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer(); - //this.wrappedBuffer.order(ByteOrder.nativeOrder()); + this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.DOUBLE), false); setData(arr); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java index c173e2745..f7f70bc75 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java @@ -50,6 +50,10 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaFloatDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -133,19 +137,6 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaFloatDataBuffer(byte[] data, long length) { - super(data, length, DataType.FLOAT); - } - - public CudaFloatDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.FLOAT); - } - - public CudaFloatDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.FLOAT); - } - - @Override protected DataBuffer create(long length) { return new CudaFloatDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java index 472e701c1..1fb55e73b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaHalfDataBuffer.java @@ -49,6 +49,10 @@ public class CudaHalfDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaHalfDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaHalfDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaHalfDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaHalfDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaHalfDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaHalfDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java index 27c0c95e3..95a9c0ce9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaIntDataBuffer.java @@ -46,6 +46,10 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaIntDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -106,11 +110,6 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - - public CudaIntDataBuffer(byte[] data, int length) { - super(data, length, DataType.INT); - } - public CudaIntDataBuffer(double[] data) { super(data); } @@ -135,14 +134,6 @@ public class CudaIntDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaIntDataBuffer(ByteBuffer buffer, int length) { - super(buffer, length, DataType.INT); - } - - public CudaIntDataBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset, DataType.INT); - } - @Override protected DataBuffer create(long length) { return new CudaIntDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java index 494148862..381ab5355 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaLongDataBuffer.java @@ -16,12 +16,14 @@ package org.nd4j.linalg.jcublas.buffer; +import lombok.Data; import lombok.NonNull; import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; @@ -30,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.nativeblas.NativeOpsHolder; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -55,8 +58,18 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaLongDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** + * This constructor is special one - it's used for ShapeInfo + * @param hostPointer + * @param devicePointer + * @param numberOfElements + */ public CudaLongDataBuffer(@NonNull Pointer hostPointer, @NonNull Pointer devicePointer, long numberOfElements) { + super(); this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.offset = 0; this.originalOffset = 0; @@ -64,14 +77,15 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { this.length = numberOfElements; initTypeAndSize(); + // creating empty native DataBuffer and filling it with pointers + ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, hostPointer, numberOfElements); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, devicePointer, numberOfElements); + + // setting up java side of things this.pointer = new CudaPointer(hostPointer, numberOfElements).asLongPointer(); indexer = LongIndexer.create((LongPointer) this.pointer); - - this.allocationPoint = AtomicAllocator.getInstance().pickExternalBuffer(this); - - val pp = new PointersPair(devicePointer, this.pointer); - allocationPoint.setPointers(pp); - trackingPoint = allocationPoint.getObjectId(); + this.allocationPoint = new AllocationPoint(ptrDataBuffer, numberOfElements * DataType.INT64.width()); } /** @@ -179,19 +193,6 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaLongDataBuffer(byte[] data, long length) { - super(data, length, DataType.LONG); - } - - public CudaLongDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.LONG); - } - - public CudaLongDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.LONG); - } - - @Override protected DataBuffer create(long length) { return new CudaLongDataBuffer(length); @@ -241,14 +242,7 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { this.length = n; this.elementSize = 8; - //wrappedBuffer = ByteBuffer.allocateDirect(length() * getElementSize()); - //wrappedBuffer.order(ByteOrder.nativeOrder()); - - this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, - new AllocationShape(length, elementSize, DataType.LONG), false); - this.trackingPoint = allocationPoint.getObjectId(); - //this.wrappedBuffer = allocationPoint.getPointers().getHostPointer().asByteBuffer(); - //this.wrappedBuffer.order(ByteOrder.nativeOrder()); + this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, DataType.LONG), false); setData(arr); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java index 9a67f56aa..645b06723 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaShortDataBuffer.java @@ -49,6 +49,10 @@ public class CudaShortDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaShortDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaShortDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaShortDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaShortDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaShortDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override protected DataBuffer create(long length) { return new CudaShortDataBuffer(length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java index 7cc944850..5447ba043 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUByteDataBuffer.java @@ -49,6 +49,10 @@ public class CudaUByteDataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUByteDataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -131,18 +135,6 @@ public class CudaUByteDataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUByteDataBuffer(byte[] data, long length) { - super(data, length, DataType.HALF); - } - - public CudaUByteDataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.HALF); - } - - public CudaUByteDataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.HALF); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java index 428cb5bcd..809363494 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt16DataBuffer.java @@ -46,6 +46,10 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUInt16DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt16DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT16); - } - - public CudaUInt16DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT16); - } - - public CudaUInt16DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT16); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java index cd34607ce..1595cfda3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt32DataBuffer.java @@ -46,6 +46,10 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer { super(pointer, specialPointer, indexer, length); } + public CudaUInt32DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt32DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT32); - } - - public CudaUInt32DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT32); - } - - public CudaUInt32DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT32); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java index 0e413827c..a107a5d8c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUInt64DataBuffer.java @@ -42,6 +42,10 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer { super(pointer, indexer, length); } + public CudaUInt64DataBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + /** * Base constructor * @@ -128,18 +132,6 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer { super(data, copy, offset); } - public CudaUInt64DataBuffer(byte[] data, long length) { - super(data, length, DataType.UINT64); - } - - public CudaUInt64DataBuffer(ByteBuffer buffer, long length) { - super(buffer, (int) length, DataType.UINT64); - } - - public CudaUInt64DataBuffer(ByteBuffer buffer, long length, long offset) { - super(buffer, length, offset, DataType.UINT64); - } - @Override public void assign(long[] indices, double[] data, boolean contiguous, long inc) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java new file mode 100644 index 000000000..50219f563 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java @@ -0,0 +1,243 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.jcublas.buffer; + + +import lombok.Getter; +import lombok.NonNull; +import lombok.val; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.LongPointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.base.Preconditions; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; + +import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; + +/** + * UTF-8 buffer + * + * @author Adam Gibson + */ +public class CudaUtf8Buffer extends BaseCudaDataBuffer { + + protected Collection references = new ArrayList<>(); + + @Getter + protected long numWords = 0; + + /** + * Meant for creating another view of a buffer + * + * @param pointer the underlying buffer to create a view from + * @param indexer the indexer for the pointer + * @param length the length of the view + */ + public CudaUtf8Buffer(Pointer pointer, Indexer indexer, long length) { + super(pointer, indexer, length); + } + + public CudaUtf8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + + public CudaUtf8Buffer(long length) { + super(length); + } + + public CudaUtf8Buffer(long length, boolean initialize) { + super((length + 1) * 8, 1, initialize); + numWords = length; + } + + public CudaUtf8Buffer(long length, boolean initialize, MemoryWorkspace workspace) { + super((length + 1) * 8, 1, initialize, workspace); + numWords = length; + } + + public CudaUtf8Buffer(int[] ints, boolean copy, MemoryWorkspace workspace) { + super(ints, copy, workspace); + } + + public CudaUtf8Buffer(byte[] data, long numWords) { + super(data.length, 1, false); + + lazyAllocateHostPointer(); + + val bp = (BytePointer) pointer; + bp.put(data); + this.numWords = numWords; + } + + public CudaUtf8Buffer(double[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(double[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(float[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(long[] data, boolean copy) { + super(data, copy); + } + + public CudaUtf8Buffer(long[] data, boolean copy, MemoryWorkspace workspace) { + super(data, copy); + } + + public CudaUtf8Buffer(float[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(int[] data, boolean copy, long offset) { + super(data, copy, offset); + } + + public CudaUtf8Buffer(int length, int elementSize) { + super(length, elementSize); + } + + public CudaUtf8Buffer(int length, int elementSize, long offset) { + super(length, elementSize, offset); + } + + public CudaUtf8Buffer(DataBuffer underlyingBuffer, long length, long offset) { + super(underlyingBuffer, length, offset); + this.numWords = length; + + Preconditions.checkArgument(((CudaUtf8Buffer) underlyingBuffer).numWords == numWords, "String array can't be a view"); + } + + public CudaUtf8Buffer(@NonNull Collection strings) { + super(CudaUtf8Buffer.stringBufferRequiredLength(strings), 1, false); + lazyAllocateHostPointer(); + + // at this point we should have fully allocated buffer, time to fill length + val headerLength = (strings.size() + 1) * 8; + val headerPointer = new LongPointer(this.pointer); + val dataPointer = new BytePointer(this.pointer); + + numWords = strings.size(); + + long cnt = 0; + long currentLength = 0; + for (val s: strings) { + headerPointer.put(cnt++, currentLength); + val length = s.length(); + val chars = s.toCharArray(); + + // putting down chars + for (int e = 0; e < length; e++) { + val b = (byte) chars[e]; + val idx = headerLength + currentLength + e; + dataPointer.put(idx, b); + } + + currentLength += length; + } + headerPointer.put(cnt, currentLength); + allocationPoint.tickHostWrite(); + } + + public String getString(long index) { + if (index > numWords) + throw new IllegalArgumentException("Requested index [" + index + "] is above actual number of words stored: [" + numWords + "]"); + + val headerPointer = new LongPointer(this.pointer); + val dataPointer = (BytePointer) (this.pointer); + + val start = headerPointer.get(index); + val end = headerPointer.get(index+1); + + if (end - start > Integer.MAX_VALUE) + throw new IllegalStateException("Array is too long for Java"); + + val dataLength = (int) (end - start); + val bytes = new byte[dataLength]; + + val headerLength = (numWords + 1) * 8; + + for (int e = 0; e < dataLength; e++) { + val idx = headerLength + start + e; + bytes[e] = dataPointer.get(idx); + } + + try { + return new String(bytes, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + + @Override + protected DataBuffer create(long length) { + return new CudaUtf8Buffer(length); + } + + @Override + public DataBuffer create(double[] data) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer create(float[] data) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer create(int[] data) { + throw new UnsupportedOperationException(); + } + + private static long stringBufferRequiredLength(@NonNull Collection strings) { + // header size first + long size = (strings.size() + 1) * 8; + + for (val s:strings) + size += s.length(); + + return size; + } + + public void put(long index, Pointer pointer) { + throw new UnsupportedOperationException(); + //references.add(pointer); + //((LongIndexer) indexer).put(index, pointer.address()); + } + + /** + * Initialize the opType of this buffer + */ + @Override + protected void initTypeAndSize() { + elementSize = 1; + type = DataType.UTF8; + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java index 72e089e45..5083a2bf9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/factory/CudaDataBufferFactory.java @@ -24,15 +24,11 @@ import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.*; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.LongBuffer; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.jcublas.buffer.*; import org.nd4j.linalg.util.ArrayUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; @@ -64,6 +60,42 @@ public class CudaDataBufferFactory implements DataBufferFactory { return allocationMode; } + @Override + public DataBuffer create(ByteBuffer underlyingBuffer, DataType dataType, long length, long offset) { + switch (dataType) { + case DOUBLE: + return new CudaDoubleDataBuffer(underlyingBuffer, dataType, length, offset); + case FLOAT: + return new CudaFloatDataBuffer(underlyingBuffer, dataType, length, offset); + case HALF: + return new CudaHalfDataBuffer(underlyingBuffer, dataType, length, offset); + case BFLOAT16: + return new CudaBfloat16DataBuffer(underlyingBuffer, dataType, length, offset); + case LONG: + return new CudaLongDataBuffer(underlyingBuffer, dataType, length, offset); + case INT: + return new CudaIntDataBuffer(underlyingBuffer, dataType, length, offset); + case SHORT: + return new CudaShortDataBuffer(underlyingBuffer, dataType, length, offset); + case UBYTE: + return new CudaUByteDataBuffer(underlyingBuffer, dataType, length, offset); + case UINT16: + return new CudaUInt16DataBuffer(underlyingBuffer, dataType, length, offset); + case UINT32: + return new CudaUInt32DataBuffer(underlyingBuffer, dataType, length, offset); + case UINT64: + return new CudaUInt64DataBuffer(underlyingBuffer, dataType, length, offset); + case BYTE: + return new CudaByteDataBuffer(underlyingBuffer, dataType, length, offset); + case BOOL: + return new CudaBoolDataBuffer(underlyingBuffer, dataType, length, offset); + case UTF8: + return new CudaUtf8Buffer(underlyingBuffer, dataType, length, offset); + default: + throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); + } + } + @Override public DataBuffer create(DataBuffer underlyingBuffer, long offset, long length) { switch (underlyingBuffer.dataType()) { @@ -94,7 +126,7 @@ public class CudaDataBufferFactory implements DataBufferFactory { case BOOL: return new CudaBoolDataBuffer(underlyingBuffer, length, offset); case UTF8: - return new Utf8Buffer(underlyingBuffer, length, offset); + return new CudaUtf8Buffer(underlyingBuffer, length, offset); default: throw new ND4JIllegalStateException("Unknown data buffer type: " + underlyingBuffer.dataType().toString()); } @@ -169,27 +201,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaIntDataBuffer(data, copy, workspace); } - @Override - public DataBuffer createInt(long offset, ByteBuffer buffer, int length) { - return new CudaIntDataBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createFloat(long offset, ByteBuffer buffer, int length) { - return new CudaFloatDataBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createDouble(long offset, ByteBuffer buffer, int length) { - return new CudaDoubleDataBuffer(buffer, length, offset); - } - - - @Override - public DataBuffer createLong(ByteBuffer buffer, int length) { - return new CudaLongDataBuffer(buffer, length); - } - @Override public DataBuffer createDouble(long offset, int length) { return new CudaDoubleDataBuffer(length, 8, offset); @@ -315,21 +326,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaIntDataBuffer(data, copy, offset); } - @Override - public DataBuffer createInt(ByteBuffer buffer, int length) { - return new CudaIntDataBuffer(buffer, length); - } - - @Override - public DataBuffer createFloat(ByteBuffer buffer, int length) { - return new CudaFloatDataBuffer(buffer, length); - } - - @Override - public DataBuffer createDouble(ByteBuffer buffer, int length) { - return new CudaDoubleDataBuffer(buffer, length); - } - @Override public DataBuffer createDouble(long length) { return new CudaDoubleDataBuffer(length); @@ -384,6 +380,8 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(length, initialize); case BOOL: return new CudaBoolDataBuffer(length, initialize); + case UTF8: + return new CudaUtf8Buffer(length, true); default: throw new UnsupportedOperationException("Unknown data type: [" + dataType + "]"); } @@ -581,16 +579,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaDoubleDataBuffer(data); } - @Override - public DataBuffer createDouble(byte[] data, int length) { - return new CudaDoubleDataBuffer(data, length); - } - - @Override - public DataBuffer createFloat(byte[] data, int length) { - return new CudaFloatDataBuffer(data, length); - } - @Override public DataBuffer createFloat(double[] data) { return new CudaFloatDataBuffer(ArrayUtil.toFloats(data)); @@ -969,18 +957,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(data); } - /** - * Creates a half-precision data buffer - * - * @param offset - * @param data the data to create the buffer from - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(long offset, byte[] data, int length) { - return new CudaHalfDataBuffer(ArrayUtil.toFloatArray(data), true, offset); - } /** * Creates a half-precision data buffer @@ -994,30 +970,6 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaHalfDataBuffer(length); } - /** - * Creates a half-precision data buffer - * - * @param buffer - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(ByteBuffer buffer, int length) { - return new CudaHalfDataBuffer(buffer, length); - } - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - @Override - public DataBuffer createHalf(byte[] data, int length) { - return new CudaHalfDataBuffer(data, length); - } - @Override public DataBuffer createDouble(long length, boolean initialize, MemoryWorkspace workspace) { return new CudaDoubleDataBuffer(length, initialize, workspace); @@ -1124,4 +1076,7 @@ public class CudaDataBufferFactory implements DataBufferFactory { return new CudaLongDataBuffer(length, initialize, workspace); } + public DataBuffer createUtf8Buffer(byte[] data, long product) { + return new CudaUtf8Buffer(data, product); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java index 19e8f8df6..f9cbb1794 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java @@ -24,28 +24,18 @@ import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.*; import org.nd4j.compression.impl.AbstractCompressor; import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; -import org.nd4j.linalg.api.buffer.IntBuffer; -import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.compression.CompressedDataBuffer; -import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan; -import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index dca1a61a9..f1bbb6d04 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -24,10 +24,8 @@ import lombok.val; import lombok.var; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.LongIndexer; -import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; -import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.tad.DeviceTADManager; @@ -36,7 +34,6 @@ import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.environment.Nd4jEnvironment; import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,6 +47,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -58,13 +56,13 @@ import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.compression.ThresholdCompression; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.AddressRetriever; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer; +import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.Pair; @@ -131,12 +129,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dimension = op.dimensions().toIntVector(); -// validateDataType(Nd4j.dataType(), op); - if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -146,9 +142,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); @@ -185,23 +182,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo, + y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context)); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo, + y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context)); break; default: throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); @@ -210,9 +202,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -252,7 +241,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -269,7 +258,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); if (extraz.get() == null) @@ -333,150 +321,118 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(argsType), context) : null; Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); - if (op instanceof Variance) { - if (ret.isScalar()) { - nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), - ((Variance) op).isBiasCorrected()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else { - nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, ((Variance) op).isBiasCorrected(), - (LongPointer) devTadShapeInfo, - (LongPointer) devTadOffsets); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } - } else if (op.y() != null) { - if (op.isComplexAccumulation()) { - - val dT = new LongPointerWrapper(devTadOffsets); - val yT = new LongPointerWrapper(yDevTadOffsets); - - nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, - (LongPointer) devTadShapeInfo, - dT, - (LongPointer) yDevTadShapeInfo, - yT); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else if (ret.isScalar()) { - nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context)); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } else { - nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, - (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - } + if (op instanceof Variance) { + if (ret.isScalar()) { + nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), + ((Variance) op).isBiasCorrected()); } else { + nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + ((Variance) op).isBiasCorrected(), + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); + } + } else if (op.y() != null) { + if (op.isComplexAccumulation()) { + + val dT = new LongPointerWrapper(devTadOffsets); + val yT = new LongPointerWrapper(yDevTadOffsets); + + nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, dT, + (LongPointer) yDevTadShapeInfo, yT); + } else if (ret.isScalar()) { + nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context)); + } else { + nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); + } + } else { if (ret.isScalar()) { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo,(LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); + z, (LongPointer) hostZShapeInfo,(LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer())); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } } @@ -610,7 +566,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); @@ -619,10 +575,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - val x = AtomicAllocator.getInstance().getPointer(op.x(), context); val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); @@ -644,22 +597,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer dimensionPointer = AtomicAllocator.getInstance() .getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); - nativeOps.execIndexReduce(xShapeInfoHostPointer, - op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + extraArgs, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -681,7 +631,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { super.exec(op); if (op.z() != null) - AtomicAllocator.getInstance().tickHostWrite(op.z()); + throw new UnsupportedOperationException("Pew-pew"); + //AtomicAllocator.getInstance().tickHostWrite(op.z()); + return null; } @@ -731,12 +683,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); @@ -780,35 +731,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { devTadShapeInfoZ, // 12 devTadOffsetsZ); // 13 - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + //log.info("X: {}; Y: {}; Z: {}; dTS: {}, dTO: {}; dTSz: {}; dTOz: {};", x.address(), y.address(), z.address(), devTadShapeInfo.address(), devTadOffsets.address(), devTadShapeInfoZ.address(), devTadOffsetsZ.address()); switch (op.getOpType()) { case BROADCAST: nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unknown opType: " + op.getOpType()); @@ -817,8 +765,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return null; @@ -851,12 +797,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z().isScalar() ? null : op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); - Pointer extraArgs = op.extraArgs() != null - ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; + Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null; val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); @@ -873,9 +817,12 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer offsets = tadBuffers.getSecond(); Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + PointerPointer xShapeInfoHostPointer = extraz.get().put( AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), @@ -884,28 +831,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) { nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); - - AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { - Arrays.sort(dimension); + if (dimension != null && dimension.length > 1) + Arrays.sort(dimension); //long dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); Pointer dimensionPointer = AtomicAllocator.getInstance() .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension)); nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - dimensionPointer, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); - - AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } if (nativeOps.lastErrorCode() != 0) @@ -919,7 +860,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext invoke(ReduceOp op, int[] dimension) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] @@ -962,7 +903,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (dimension == null ) dimension = new int[] {Integer.MAX_VALUE}; - if (dimension.length > 1) + if (dimension != null && dimension.length > 1) Arrays.sort(dimension); for (int i = 0; i < dimension.length; i++) @@ -981,7 +922,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); @@ -1044,139 +984,114 @@ public class CudaExecutioner extends DefaultOpExecutioner { xShapeInfoHostPointer.put(13, yDevTadOffsets); } - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); - //log.info("Op.X address: {};", x.address()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); op.validateDataTypes(); if (op.z().isScalar()) { if (op instanceof Variance) { nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, ((Variance) op).isBiasCorrected()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } else if (op.y() != null) { - Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context); Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_BOOL: nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_SAME: nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; case REDUCE_LONG: nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo); break; default: throw new UnsupportedOperationException(); } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } } else { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); //AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); if (op.y() != null) { - val y = AtomicAllocator.getInstance().getPointer(op.y(), context); val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - dimensionPointer, null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets); } else { if (op instanceof Variance) { nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected(), - (LongPointer) devTadShapeInfo, - (LongPointer) devTadOffsets); + (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets); } else { switch (op.getOpType()) { case REDUCE_FLOAT: nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, extraArgs, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - (IntPointer) op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null); + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException(); } } } - - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } if (nativeOps.lastErrorCode() != 0) @@ -1193,9 +1108,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext intercept(ScalarOp op, int[] dimension) { long st = profilingConfigurableHookIn(op); - Arrays.sort(dimension); + if (dimension != null && dimension.length > 1) + Arrays.sort(dimension); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -1204,9 +1120,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - val x = AtomicAllocator.getInstance().getPointer(op.x(), context); - val y = AtomicAllocator.getInstance().getPointer(op.y(), context); - val z = AtomicAllocator.getInstance().getPointer(op.z(), context); val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context); val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); @@ -1239,30 +1152,28 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR: nativeOps.execScalarTad(extraPointers, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, extraArgs, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; case SCALAR_BOOL: nativeOps.execScalarBoolTad(extraPointers, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, extraArgs, - null, - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - AtomicAllocator.getInstance().getPointer(op.dimensions(), context), - null, + ((BaseCudaDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; @@ -1273,8 +1184,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return null; @@ -1322,17 +1231,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { return null; } - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); - Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null; - Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context); Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context); PointerPointer xShapeInfoHostPointer = extraz.get().put( @@ -1341,19 +1248,23 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.scalar() == null ? null : ((BaseCudaDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR_BOOL: nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; case SCALAR: nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs); break; default: @@ -1363,8 +1274,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar()); - profilingConfigurableHookOut(op, st); return null; @@ -1382,7 +1291,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = allocator.getDeviceContext(); if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); @@ -1390,7 +1299,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { // special temp array for IsMax along dimension INDArray ret = null; - Pointer x = allocator.getPointer(op.x(), context); Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context); @@ -1426,7 +1334,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { op.validateDataTypes(experimentalMode.get()); - Pointer z = allocator.getPointer(op.z(), context); Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); @@ -1453,31 +1360,30 @@ public class CudaExecutioner extends DefaultOpExecutioner { retHostShape); - + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op.y() != null) { - Pointer y = allocator.getPointer(op.y(), context); Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context); if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform"); - ///log.info("X: {}; Y: {}; Z: {}; E: {};", x.address(), y.address(), z.address(), extraArgs != null ? extraArgs.address() : null); - switch (op.getOpType()) { case TRANSFORM_BOOL: case PAIRWISE_BOOL: nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; } @@ -1485,32 +1391,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case TRANSFORM_ANY: nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_FLOAT: nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_BOOL: nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_SAME: nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_STRICT: nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), - null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo, + z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo, extraArgs); break; default: @@ -1521,8 +1427,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); - if (extraArgs != null) extraArgs.address(); @@ -1543,146 +1447,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void exec(Batch batch) { - val surfaceBuffer = (BaseCudaDataBuffer) getBuffer(batch); - surfaceBuffer.lazyAllocateHostPointer(); - - val context = AtomicAllocator.getInstance().getDeviceContext(); - - val pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) - .asIntPointer(); - val surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); - - int maxTypes = 5; - - int maxIntArrays = batch.getSample().maxIntArrays(); - - int maxArraySize = batch.getSample().maxIntArraySize(); - - - int indexPos = maxTypes * (Batch.getBatchLimit() * 16); - int intArraysPos = indexPos + (batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16)); - int realPos = (intArraysPos + (maxIntArrays * maxArraySize * (Batch.getBatchLimit() * 16))) - / (Nd4j.dataType() == DataType.DOUBLE ? 2 : 1); - - if (Nd4j.dataType() == DataType.HALF) - realPos *= 2; - - int argsPos = (realPos + (batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16))) - / (Nd4j.dataType() == DataType.FLOAT ? 2 : 1); - - if (Nd4j.dataType() == DataType.HALF) - argsPos /= 4; - - int shapesPos = argsPos + (batch.getSample().maxArguments() * (Batch.getBatchLimit() * 16)); - DataType dataType = null; - for (int i = 0; i < batch.getNumAggregates(); i++) { - T op = batch.getAggregates().get(i); - - if (i == 0) - dataType = op.getArguments().get(0).dataType(); - - // put num arguments - int idx = i * maxTypes; - pointer.put(idx, op.getArguments().size()); - pointer.put(idx + 1, op.getShapes().size()); - pointer.put(idx + 2, op.getIndexingArguments().size()); - pointer.put(idx + 3, op.getRealArguments().size()); - pointer.put(idx + 4, op.getIntArrayArguments().size()); - - - // putting indexing arguments - for (int e = 0; e < op.getIndexingArguments().size(); e++) { - idx = indexPos + i * batch.getSample().maxIndexArguments(); - pointer.put(idx + e, op.getIndexingArguments().get(e)); - } - - // putting intArray values - int bsize = maxIntArrays * maxArraySize; - for (int e = 0; e < op.getIntArrayArguments().size(); e++) { - int step = (i * bsize) + (e * maxArraySize); - if (op.getIntArrayArguments().get(e) != null) - for (int x = 0; x < op.getIntArrayArguments().get(e).length; x++) { - idx = intArraysPos + step + x; - pointer.put(idx, op.getIntArrayArguments().get(e)[x]); - } - } - - // TODO: variable datatype should be handled here - // putting real arguments - switch (dataType) { - case FLOAT: { - FloatPointer realPtr = new FloatPointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + i * op.maxRealArguments(); - realPtr.put(idx + e, op.getRealArguments().get(e).floatValue()); - } - } - break; - case DOUBLE: { - DoublePointer dPtr = new DoublePointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + (i * op.maxRealArguments()); - dPtr.put(idx + e, op.getRealArguments().get(e).doubleValue()); - } - } - break; - case HALF: { - ShortPointer sPtr = new ShortPointer(pointer); - for (int e = 0; e < op.getRealArguments().size(); e++) { - idx = realPos + (i * op.maxRealArguments()); - sPtr.put(idx + e, BaseDataBuffer.fromFloat(op.getRealArguments().get(e).floatValue())); - } - } - break; - default: - throw new UnsupportedOperationException("Unknown data type"); - } - - // putting arguments pointers - PointerPointer ptrPtr = new PointerPointer(pointer); - for (int e = 0; e < op.getArguments().size(); e++) { - idx = argsPos + i * batch.getSample().maxArguments(); - - if (op.getArguments().get(e) != null) { - ptrPtr.put(idx + e, AtomicAllocator.getInstance().getPointer(op.getArguments().get(e), context)); - AtomicAllocator.getInstance().getAllocationPoint(op.getArguments().get(e)).tickDeviceWrite(); - } - } - - - // putting shape pointers - for (int e = 0; e < op.getShapes().size(); e++) { - idx = shapesPos + i * batch.getSample().maxShapes(); - - if (op.getShapes().get(e) != null) { - ptrPtr.put(idx + e, AtomicAllocator.getInstance().getPointer(op.getShapes().get(e), context)); - AtomicAllocator.getInstance().getAllocationPoint(op.getShapes().get(e)).tickDeviceWrite(); - } - } - } - - // trigger write, so getPointer request will force relocation to GPU - surfacePoint.tickHostWrite(); - - PointerPointer extraArgs = new PointerPointer(32); - extraArgs.put(0, null); - extraArgs.put(1, context.getOldStream()); - extraArgs.put(2, new CudaPointer(Math.min(batch.getNumAggregates(), - CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize()))); - extraArgs.put(3, new CudaPointer(batch.getSample().getThreadsPerInstance())); - extraArgs.put(4, new CudaPointer(batch.getSample().getSharedMemorySize())); - - - nativeOps.execAggregateBatch(extraArgs, batch.getNumAggregates(), batch.opNum(), - batch.getSample().maxArguments(), batch.getSample().maxShapes(), - batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), - batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), - AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType)); - - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); - - surfacePoint.tickHostWrite(); + throw new UnsupportedOperationException("Pew-pew"); } @Override @@ -1701,84 +1466,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void exec(Aggregate op) { - int numArguments = op.getArguments().size(); - int numShapeArguments = op.getShapes().size(); - int numIndexArguments = op.getIndexingArguments().size(); - int numIntArrays = op.getIntArrayArguments().size(); - int numRealArguments = op.getRealArguments().size(); - - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - - val extraArgs = new PointerPointer(32); - extraArgs.put(0, null); - extraArgs.put(1, context.getOldStream()); - extraArgs.put(2, new CudaPointer(1)); - extraArgs.put(3, new CudaPointer(op.getThreadsPerInstance())); - extraArgs.put(4, new CudaPointer(op.getSharedMemorySize())); - - long arguments[] = new long[numArguments]; - val dataType = op.getArguments().get(0).dataType(); - - for (int x = 0; x < numArguments; x++) { - arguments[x] = op.getArguments().get(x) == null ? 0 - : AtomicAllocator.getInstance().getPointer(op.getArguments().get(x), context).address(); - - if (op.getArguments().get(x) != null) - AtomicAllocator.getInstance().getAllocationPoint(op.getArguments().get(x)).tickDeviceWrite(); - } - - DataBuffer tempX = AllocationUtils.getPointersBuffer(arguments); - PointerPointer xPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)); - - - long shapes[] = new long[numShapeArguments]; - for (int x = 0; x < numShapeArguments; x++) { - shapes[x] = op.getShapes().get(x) == null ? 0 - : AtomicAllocator.getInstance().getPointer(op.getShapes().get(x), context).address(); - - if (op.getShapes().get(x) != null) - AtomicAllocator.getInstance().getAllocationPoint(op.getShapes().get(x)).tickDeviceWrite(); - } - - DataBuffer tempS = AllocationUtils.getPointersBuffer(shapes); - PointerPointer sPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempS, context)); - - - long ints[] = new long[numIntArrays]; - for (int x = 0; x < numIntArrays; x++) { - if (op.getIntArrayArguments().get(x) != null) { - DataBuffer intBuf = Nd4j.getDataBufferFactory().createInt(op.getIntArrayArguments().get(x)); - ints[x] = AtomicAllocator.getInstance().getPointer(intBuf, context).address(); - } - - } - - DataBuffer tempI = AllocationUtils.getPointersBuffer(ints); - PointerPointer iPtr = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempI, context)); - - int[] indexes = new int[numIndexArguments]; - for (int x = 0; x < numIndexArguments; x++) { - indexes[x] = op.getIndexingArguments().get(x); - } - - DataBuffer intBuffer = Nd4j.getDataBufferFactory().createInt(indexes); - - double[] reals = new double[numRealArguments]; - INDArray realsBuffer; - for (int x = 0; x < numRealArguments; x++) { - reals[x] = op.getRealArguments().get(x).doubleValue(); - } - - realsBuffer = Nd4j.create(reals, new long[]{reals.length}, dataType); - - nativeOps.execAggregate(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, - (IntPointer) AtomicAllocator.getInstance().getPointer(intBuffer, context), - numIndexArguments, iPtr, numIntArrays, - AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context), - numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType)); - - if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); + throw new UnsupportedOperationException("Pew-pew"); } /** @@ -1810,7 +1498,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getDeviceContext(); PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); @@ -1819,34 +1507,36 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()); + val x = op.x() == null ? null : ((BaseCudaDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCudaDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCudaDataBuffer) op.z().data()).getOpaqueDataBuffer(); + if (op.x() != null && op.y() != null && op.z() != null) { // triple arg call nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostXShapeInfo, AtomicAllocator.getInstance().getPointer(op.x(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - null, (LongPointer) hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), + y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); } else if (op.x() != null && op.z() != null) { //double arg call nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostXShapeInfo, AtomicAllocator.getInstance().getPointer(op.x(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + x, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()),context)); } else { // single arg call nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr - null, (LongPointer) hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), + z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); } if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); - profilingConfigurableHookOut(op, st); return op.z(); @@ -1944,6 +1634,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val extras = extraz.get().put(1, context.getOldStream()); + ((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer().syncToSpecial(); NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1(extras, @@ -1967,22 +1658,16 @@ public class CudaExecutioner extends DefaultOpExecutioner { blocksBuffer.put(0, numMatches); } -/* - log.info("Totals: {}", numMatches); - - - log.info("Number of blocks for compression: {}", numBlocks); - log.info("BlocksCounts: {}", Arrays.toString(blocksBuffer.asInt())); -*/ DataBuffer encodedBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(4+numMatches, false) : Nd4j.getDataBufferFactory().createInt(4+numMatches, false, Nd4j.getMemoryManager().getCurrentWorkspace()); - AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite(); + encodedBuffer.put(0, numMatches); encodedBuffer.put(1, (int) buffer.length()); encodedBuffer.put(2, Float.floatToIntBits((float) threshold)); - AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite(); encodedBuffer.put(3, ThresholdCompression.FLEXIBLE_ENCODING); + ((BaseCudaDataBuffer) encodedBuffer).getOpaqueDataBuffer().syncToSpecial(); + int prefixThreads = 512; int numElts = numBlocks; @@ -2095,7 +1780,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { // format id buffer.put(3, ThresholdCompression.BITMAP_ENCODING); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(indArray); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2108,17 +1793,20 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.getBufferReduction() ); + + val src = AtomicAllocator.getInstance().getPointer(indArray, context); + val dst = (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context); + ((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer().syncToSpecial(); + long val = nativeOps.encodeBitmap(extras, - AtomicAllocator.getInstance().getPointer(indArray, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(indArray.shapeInfoDataBuffer()), + src, (LongPointer) AtomicAllocator.getInstance().getHostPointer(indArray.shapeInfoDataBuffer()), length, - (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context), + dst, (float) threshold); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray); - AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite(); return val; @@ -2127,7 +1815,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray bitmapDecode(INDArray encoded, INDArray target) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(target); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2144,8 +1832,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - AtomicAllocator.getInstance().getFlowController().registerAction(context, target); - return target; } @@ -2220,15 +1906,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - val inputBuffers = new PointerPointer<>(op.inputArguments().length * 2); - val inputShapes = new PointerPointer<>(op.inputArguments().length); + val inputBuffers = new PointerPointer<>(op.inputArguments().size() * 2); + val inputShapes = new PointerPointer<>(op.inputArguments().size()); int cnt= 0; for (val in: op.inputArguments()) { // NOT A TYPO: shape functions work on host side only if (!in.isEmpty()) { inputBuffers.put(cnt, in.data().addressPointer()); - inputBuffers.put(cnt + op.inputArguments().length, AtomicAllocator.getInstance().getPointer(in.data())); + inputBuffers.put(cnt + op.inputArguments().size(), AtomicAllocator.getInstance().getPointer(in.data())); } inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); @@ -2253,7 +1939,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val t: op.tArgs()) tArgs.put(cnt++, t); - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); @@ -2324,6 +2010,17 @@ public class CudaExecutioner extends DefaultOpExecutioner { val result = exec(op, context); val states = context.getRngStates(); + // check if input && output needs update + for (val in:op.inputArguments()) { + if (!in.isEmpty()) + ((BaseCudaDataBuffer) in.data()).actualizePointerAndIndexer(); + } + + for (val out:op.outputArguments()) { + if (!out.isEmpty()) + ((BaseCudaDataBuffer) out.data()).actualizePointerAndIndexer(); + } + // pulling states back Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); @@ -2407,7 +2104,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { val array = Nd4j.create(shapeOf, stridesOf, 0, order); Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); - AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); + //AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); + if (1 > 0) + throw new UnsupportedOperationException("Pew-pew"); String nodeName = nativeOps.getVariableName(var); newMap.put(nodeName, array); @@ -2463,7 +2162,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { + Preconditions.checkArgument(buffer instanceof CudaUtf8Buffer, "Expected Utf8Buffer"); + val addr = ((LongIndexer) buffer.indexer()).get(index); val ptr = new PagedPointer(addr); val str = new Nd4jCuda.utf8string(ptr); @@ -2477,7 +2178,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, @NonNull int[] axis) { - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(array, indices, updates); + val context = AtomicAllocator.getInstance().getDeviceContext(); val tadX = tadManager.getTADOnlyShapeInfo(array, axis); val tadY = tadManager.getTADOnlyShapeInfo(updates, axis); @@ -2497,8 +2198,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); - - AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates); } @Override @@ -2520,13 +2219,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); - - - for (val arr:op.outputArguments()) - AtomicAllocator.getInstance().registerAction(ctx, arr); - - AtomicAllocator.getInstance().registerAction(ctx, null, op.inputArguments()); - profilingConfigurableHookOut(op, st); if (context.getOutputArrays().isEmpty()) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index d37a0184d..6f37be02a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOps; @@ -88,8 +89,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setInputArray(int index, @NonNull INDArray array) { val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(null, array); - - nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextInputBuffer(context, index, array.isEmpty() ? null : ((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setInputArray(index, array); } @@ -97,33 +97,13 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public void setOutputArray(int index, @NonNull INDArray array) { val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(array, null); - - nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextOutputBuffer(context, index, array.isEmpty() ? null : ((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setOutputArray(index, array); } @Override public Pointer contextPointer() { - for (val v:fastpath_in.values()) { - if (v.isEmpty() || v.isS()) - continue; - - AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); - AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceRead(); - - //if (context.isInplace()) - //AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceWrite(); - } - - for (val v:fastpath_out.values()) { - if (v.isEmpty() || v.isS()) - continue; - - AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); - AtomicAllocator.getInstance().getAllocationPoint(v).tickDeviceRead(); - } - return context; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 6dab1ab01..5bfd0bf44 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -175,6 +175,74 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { } } +@Name("std::vector") public static class ConstNDArrayVector extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstNDArrayVector(Pointer p) { super(p); } + public ConstNDArrayVector(NDArray value) { this(1); put(0, value); } + public ConstNDArrayVector(NDArray ... array) { this(array.length); put(array); } + public ConstNDArrayVector() { allocate(); } + public ConstNDArrayVector(long n) { allocate(n); } + private native void allocate(); + private native void allocate(@Cast("size_t") long n); + public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + + public boolean empty() { return size() == 0; } + public native long size(); + public void clear() { resize(0); } + public native void resize(@Cast("size_t") long n); + + @Index(function = "at") public native @Const NDArray get(@Cast("size_t") long i); + public native ConstNDArrayVector put(@Cast("size_t") long i, NDArray value); + + public native @ByVal Iterator insert(@ByVal Iterator pos, @Const NDArray value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @Const NDArray get(); + } + + public NDArray[] get() { + NDArray[] array = new NDArray[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + + public NDArray pop_back() { + long size = size(); + NDArray value = get(size - 1); + resize(size - 1); + return value; + } + public ConstNDArrayVector push_back(NDArray value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public ConstNDArrayVector put(NDArray value) { + if (size() != 1) { resize(1); } + return put(0, value); + } + public ConstNDArrayVector put(NDArray ... array) { + if (size() != array.length) { resize(array.length); } + for (int i = 0; i < array.length; i++) { + put(i, array[i]); + } + return this; + } +} + @NoOffset @Name("std::pair") public static class IntIntPair extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -240,12 +308,167 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200; // #endif +// Parsed from array/DataBuffer.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +// #ifndef DEV_TESTS_DATABUFFER_H +// #define DEV_TESTS_DATABUFFER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +@Namespace("nd4j") @NoOffset public static class DataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public DataBuffer position(long position) { + return (DataBuffer)super.position(position); + } + + + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); + + public native @Cast("nd4j::DataType") int getDataType(); + public native void setDataType(@Cast("nd4j::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); +} +///// IMLEMENTATION OF INLINE METHODS ///// + +//////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////// + + + + + +// #endif //DEV_TESTS_DATABUFFER_H + + // Parsed from array/ConstantDescriptor.h /******************************************************************************* @@ -753,6 +976,7 @@ bool verbose = false; // #include // #include // #include +// #include // #include // #include // #include @@ -801,25 +1025,19 @@ public native void setTADThreshold(int num); */ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -834,31 +1052,22 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex */ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -875,74 +1084,50 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -959,63 +1144,45 @@ public native void execBroadcastBool( public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1029,92 +1196,68 @@ public native void execPairwiseTransformBool( */ public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1127,118 +1270,82 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -1253,31 +1360,22 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi */ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1290,31 +1388,22 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer */ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1330,82 +1419,58 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP */ public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("Nd4jLong*") long[] yTadOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer xTadShapeInfo, @Cast("Nd4jLong*") LongPointer xOffsets, @Cast("Nd4jLong*") LongPointer yTadShapeInfo, @Cast("Nd4jLong*") LongPointer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("Nd4jLong*") LongBuffer xOffsets, @Cast("Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("Nd4jLong*") LongBuffer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] xTadShapeInfo, @Cast("Nd4jLong*") long[] xOffsets, @Cast("Nd4jLong*") long[] yTadShapeInfo, @Cast("Nd4jLong*") long[] yOffsets); @@ -1422,58 +1487,40 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); /** @@ -1485,27 +1532,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1518,27 +1559,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e */ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1553,35 +1588,26 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo */ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets); @@ -1597,112 +1623,82 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1720,81 +1716,57 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @@ -2157,10 +2129,8 @@ public native void deleteTadPack(OpaqueTadPack ptr); * @param zTadOffsets */ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongPointer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer zShapeInfo, @Cast("Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongPointer indexes, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @@ -2168,10 +2138,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer zTadShapeInfo, @Cast("Nd4jLong*") LongPointer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer zShapeInfo, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongBuffer indexes, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @@ -2179,10 +2147,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer zTadShapeInfo, @Cast("Nd4jLong*") LongBuffer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") long[] dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] zShapeInfo, @Cast("Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") long[] indexes, @Cast("Nd4jLong*") long[] tadShapeInfo, @@ -2448,20 +2414,17 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2480,32 +2443,23 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeBuffer, @Cast("Nd4jLong*") long[] dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2522,26 +2476,20 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); @@ -2584,52 +2532,6 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - /** * * @param data @@ -2792,23 +2694,20 @@ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") l * @return */ public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, - @Cast("Nd4jLong*") LongPointer tadShapeInfo, - @Cast("Nd4jLong*") LongPointer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, + @Cast("Nd4jLong*") LongPointer tadShapeInfo, + @Cast("Nd4jLong*") LongPointer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, - @Cast("Nd4jLong*") long[] tadShapeInfo, - @Cast("Nd4jLong*") long[] tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, + @Cast("Nd4jLong*") long[] tadShapeInfo, + @Cast("Nd4jLong*") long[] tadOffsets); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong") long N, IntPointer dz, float threshold); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong") long N, IntBuffer dz, float threshold); @@ -3105,6 +3004,8 @@ public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") bool public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); @@ -3137,6 +3038,28 @@ public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); +public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); +public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); +public native int dbLocality(OpaqueDataBuffer dataBuffer); +public native int dbDeviceId(OpaqueDataBuffer dataBuffer); +public native void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId); +public native void dbTickHostRead(OpaqueDataBuffer dataBuffer); +public native void dbTickHostWrite(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceRead(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer); +public native void dbClose(OpaqueDataBuffer dataBuffer); +public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); + public native int binaryLevel(); public native int optimalLevel(); @@ -3633,6 +3556,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include +// #include @@ -3844,10 +3769,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -6396,10 +6324,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setInputArray(int index, NDArray array); public native void setInputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setOutputArray(int index, NDArray array); public native void setOutputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setTArguments(DoublePointer arguments, int numberOfArguments); public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); @@ -10274,6 +10204,9 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include +// #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 8c2109f7c..7b649b488 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -33,6 +33,7 @@ import org.bytedeco.javacpp.tools.InfoMapper; @Properties(target = "org.nd4j.nativeblas.Nd4jCuda", helper = "org.nd4j.nativeblas.Nd4jCudaHelper", value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "array/DataType.h", + "array/DataBuffer.h", "array/ConstantDescriptor.h", "array/ConstantDataBuffer.h", "array/TadPack.h", @@ -165,6 +166,7 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper { .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) + .put(new Info("OpaqueDataBuffer").pointerTypes("OpaqueDataBuffer")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", @@ -185,10 +187,11 @@ public class Nd4jCudaPresets implements LoadEnabled, InfoMapper { "nd4j::graph::FlatResult", "nd4j::graph::FlatVariable", "nd4j::NDArray::subarray").skip()) .put(new Info("std::string").annotations("@StdString").valueTypes("BytePointer", "String") .pointerTypes("@Cast({\"char*\", \"std::string*\"}) BytePointer")) - .put(new Info("std::pair").pointerTypes("IntIntPair").define()) - .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) + .put(new Info("std::pair").pointerTypes("IntIntPair").define()) + .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) .put(new Info("std::vector >").pointerTypes("LongVectorVector").define()) - .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) + .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) + .put(new Info("std::vector").pointerTypes("ConstNDArrayVector").define()) .put(new Info("bool").cast().valueTypes("boolean").pointerTypes("BooleanPointer", "boolean[]")) .put(new Info("nd4j::graph::ResultWrapper").base("org.nd4j.nativeblas.ResultWrapperAbstraction").define()) .put(new Info("nd4j::IndicesList").purify()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java deleted file mode 100644 index c19adf4ad..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java +++ /dev/null @@ -1,552 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator; - -import lombok.extern.slf4j.Slf4j; -import lombok.var; -import org.apache.commons.lang3.RandomUtils; -import org.bytedeco.javacpp.Pointer; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.impl.MemoryTracker; - -import lombok.val; - -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.jita.flow.FlowController; -import org.nd4j.jita.memory.impl.CudaFullCachingProvider; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.MemoryKind; -import org.nd4j.linalg.api.memory.enums.MirroringPolicy; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.executors.ExecutorServiceProvider; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.api.memory.enums.AllocationPolicy; -import org.nd4j.jita.memory.impl.CudaDirectProvider; -import org.nd4j.jita.memory.impl.CudaCachingZeroProvider; -import org.nd4j.jita.allocator.utils.AllocationUtils; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationPoint; -import org.nd4j.jita.allocator.enums.AllocationStatus; -import org.nd4j.jita.allocator.impl.AllocationShape; -import org.nd4j.linalg.primitives.Pair; - -import java.util.*; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; - -import static org.junit.Assert.*; - -@Slf4j -@Ignore("AB 2019/05/23 - Getting stuck (tests never finishing) on CI - see issue #7657") -public class AllocatorTest { - private static final long SAFETY_OFFSET = 1024L; - - @Test - public void testCounters() { - int deviceId = 0; - MemoryTracker tracker = new MemoryTracker(); - - assertTrue(0 == tracker.getAllocatedAmount(deviceId)); - assertTrue(0 == tracker.getCachedAmount(deviceId)); - //assertTrue(0 == tracker.getTotalMemory(deviceId)); - - tracker.incrementAllocatedAmount(deviceId, 10); - assertTrue(10 == tracker.getAllocatedAmount(deviceId)); - - tracker.incrementCachedAmount(deviceId, 5); - assertTrue(5 == tracker.getCachedAmount(deviceId)); - - tracker.decrementAllocatedAmount(deviceId, 5); - assertTrue(5 == tracker.getAllocatedAmount(deviceId)); - - tracker.decrementCachedAmount(deviceId, 5); - assertTrue(0 == tracker.getCachedAmount(deviceId)); - - //assertTrue(0 == tracker.getTotalMemory(deviceId)); - - for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { - val ttl = tracker.getTotalMemory(e); - log.info("Device_{} {} bytes", e, ttl); - assertNotEquals(0, ttl); - } - } - - @Test - public void testWorkspaceInitSize() { - - long initSize = 1024; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test121")) { - assertEquals(initSize + SAFETY_OFFSET, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test121"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - - @Test - public void testWorkspaceSpilledSize() { - - long initSize = 0; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test99323")) { - assertEquals(0L, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f); - - assertEquals(array.length() * array.data().getElementSize(), tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test99323"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - @Test - public void testWorkspaceSpilledSizeHost() { - - long initSize = 0; - MemoryTracker tracker = MemoryTracker.getInstance(); - - WorkspaceConfiguration workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .policyMirroring(MirroringPolicy.HOST_ONLY) - .initialSize(initSize) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test99323222")) { - assertEquals(0L, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("test99323222"); - ws.destroyWorkspace(); - - assertEquals(0, tracker.getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - - - @Ignore - @Test - public void testWorkspaceAlloc() { - - long initSize = 0; - long allocSize = 48; - - val workspaceConfig = WorkspaceConfiguration.builder() - .policyAllocation(AllocationPolicy.STRICT) - .initialSize(initSize) - .policyMirroring(MirroringPolicy.HOST_ONLY) // Commenting this out makes it so that assert is not triggered (for at least 40 secs or so...) - .build(); - - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfig, "test")) { - final INDArray zeros = Nd4j.zeros(allocSize, 'c'); - System.out.println("Alloc1:" + MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - assertTrue(allocSize == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - } - assertTrue(allocSize == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - /*Nd4j.getWorkspaceManager().destroyWorkspace(ws); - assertTrue(0L == - MemoryTracker.getInstance().getWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()));*/ - } - - @Test - public void testDirectProvider() { - INDArray input = Nd4j.zeros(1024); - CudaDirectProvider provider = new CudaDirectProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertEquals(cachedBefore, cachedAfter); - } - - @Test - public void testZeroCachingProvider() { - INDArray input = Nd4j.zeros(1024); - CudaCachingZeroProvider provider = new CudaCachingZeroProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertEquals(cachedBefore, cachedAfter); - } - - @Test - public void testFullCachingProvider() { - INDArray input = Nd4j.zeros(1024); - val provider = new CudaFullCachingProvider(); - AllocationShape shape = AllocationUtils.buildAllocationShape(input); - AllocationPoint point = new AllocationPoint(); - point.setShape(shape); - - val allocBefore = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedBefore = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - val pointers = provider.malloc(shape, point, AllocationStatus.DEVICE); - point.setPointers(pointers); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocMiddle = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedMiddle = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - provider.free(point); - - System.out.println(MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - System.out.println(MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val allocAfter = MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - val cachedAfter = MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - assertTrue(allocBefore < allocMiddle); - assertEquals(allocBefore, allocAfter); - - //assertEquals(0, cachedBefore); - //assertEquals(0, cachedMiddle); - //assertEquals(shape.getNumberOfBytes(), cachedAfter); - - assertEquals(cachedBefore, cachedMiddle); - assertTrue(cachedBefore < cachedAfter); - } - - @Test - public void testCyclicCreation() throws Exception { - Nd4j.create(100); - - log.info("Approximate free memory: {}", MemoryTracker.getInstance().getApproximateFreeMemory(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Real free memory: {}", MemoryTracker.getInstance().getPreciseFreeMemory(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - - val timeStart = System.currentTimeMillis(); - - while (true) { - //val array = Nd4j.create(DataType.FLOAT, 1000, 1000); - val array = Nd4j.create(DataType.FLOAT, RandomUtils.nextInt(100, 1000), RandomUtils.nextInt(100, 1000)); - - val timeEnd = System.currentTimeMillis(); - if (timeEnd - timeStart > 5 * 60 * 1000) { - log.info("Exiting..."); - break; - } - } - - while (true) { - log.info("Cached device memory: {}", MemoryTracker.getInstance().getCachedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Active device memory: {}", MemoryTracker.getInstance().getAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread())); - log.info("Cached host memory: {}", MemoryTracker.getInstance().getCachedHostAmount()); - log.info("Active host memory: {}", MemoryTracker.getInstance().getAllocatedHostAmount()); - - System.gc(); - Thread.sleep(30000); - } - } - - @Test - public void testAllocations() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - assertArrayEquals(new long[]{10, 5}, x.shape()); - - for (DataType dataType : DataType.values()) { - for (int i = 0; i < 10; ++i) { - - x = Nd4j.create(DataType.FLOAT, 10 * i + 1, 5 * i + 2); - assertArrayEquals(new long[]{10 * i + 1, 5 * i + 2}, x.shape()); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertNotNull(pointX); - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertEquals(64, pointX.getShape().getNumberOfBytes()); - } - } - } - - @Test - public void testAllocations1() { - INDArray x = Nd4j.zeros(1,10); - - for (int i = 0; i < 100000; ++i) { - INDArray toAdd = Nd4j.ones(1,10); - x.putRow(i+1, toAdd); - } - - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertNotNull(pointX); - - assertNotNull(pointX); - assertTrue(x.shapeInfoDataBuffer().isConstant()); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertEquals(64, pointX.getShape().getNumberOfBytes()); - } - - @Test - public void testReallocate() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - var pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertNotNull(pointX); - - assertEquals(200, pointX.getShape().getNumberOfBytes()); - - val hostP = pointX.getHostPointer(); - val deviceP = pointX.getDevicePointer(); - - assertEquals(50, x.data().capacity()); - x.data().reallocate(500); - - pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertEquals(500, x.data().capacity()); - assertEquals(2000, pointX.getShape().getNumberOfBytes()); - - assertNotEquals(hostP, pointX.getHostPointer()); - assertNotEquals(deviceP, pointX.getDevicePointer()); - } - - @Test - public void testDataMigration() { - - for (boolean p2pEnabled : new boolean[]{true, false}) { - - CudaEnvironment.getInstance().getConfiguration().allowCrossDeviceAccess(p2pEnabled); - - Thread[] threads = new Thread[4]; - List> sumsPerList = new ArrayList<>(); - List lst = new ArrayList<>(); - - for (int i = 0; i < 4; ++i) { - threads[i] = new Thread() { - @Override - public void run() { - INDArray x = Nd4j.rand(1, 10); - Pair pair = new Pair<>(); - pair.setFirst(Nd4j.sum(x)); - pair.setSecond(x); - sumsPerList.add(pair); - lst.add(x); - } - }; - threads[i].start(); - } - - try { - for (val thread : threads) { - thread.join(); - } - } catch (InterruptedException e) { - log.info("Interrupted"); - } - - Collections.shuffle(lst); - - for (int i = 0; i < lst.size(); ++i) { - INDArray data = lst.get(i); - - for (int j = 0; j < sumsPerList.size(); ++j) { - if (sumsPerList.get(j).getFirst().equals(data)) - assertEquals(sumsPerList.get(j).getSecond(), data); - - } - } - } - } - - - @Ignore - @Test - public void testHostFallback() { - // Take device memory - long bytesFree = MemoryTracker.getInstance().getApproximateFreeMemory(0); - Pointer p = Nd4j.getMemoryManager().allocate((long)(bytesFree*0.75), MemoryKind.DEVICE, true); - - // Fallback to host - INDArray x1 = Nd4j.create(1, (long)(bytesFree*0.15)); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x1.shapeInfoDataBuffer()); - - assertNotNull(pointX); - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - Nd4j.getMemoryManager().release(p, MemoryKind.DEVICE); - } - - @Test - public void testAffinityGuarantees() { - ExecutorService service = ExecutorServiceProvider.getExecutorService(); - final INDArray steady = Nd4j.rand(1,100); - Map deviceData = new HashMap<>(); - - Future>[] results = new Future[10]; - for (int i = 0; i < results.length; ++i) { - results[i] = service.submit(new Callable>() { - @Override - public List call() { - List retVal = new ArrayList<>(); - for (int i = 0; i < 100; ++i) { - INDArray x = Nd4j.rand(1, 100); - System.out.println("Device for x:" + Nd4j.getAffinityManager().getDeviceForArray(x)); - System.out.println("Device for steady: " + Nd4j.getAffinityManager().getDeviceForArray(steady)); - deviceData.put(x, Nd4j.getAffinityManager().getDeviceForArray(x)); - deviceData.put(steady, Nd4j.getAffinityManager().getDeviceForArray(steady)); - retVal.add(x); - } - Thread[] innerThreads = new Thread[4]; - for (int k = 0; k < 4; ++k) { - innerThreads[k] = new Thread() { - @Override - public void run() { - for (val res : retVal) { - assertEquals(deviceData.get(res), Nd4j.getAffinityManager().getDeviceForArray(res)); - assertEquals(deviceData.get(steady), Nd4j.getAffinityManager().getDeviceForArray(steady)); - } - } - }; - innerThreads[k].start(); - } - try { - for (int k = 0; k < 4; ++k) { - innerThreads[k].join(); - } - } catch (InterruptedException e) { - log.info(e.getMessage()); - } - return retVal; - } - }); - - try { - List resArray = results[i].get(); - for (val res : resArray) { - assertEquals(deviceData.get(res), Nd4j.getAffinityManager().getDeviceForArray(res)); - assertEquals(deviceData.get(steady), Nd4j.getAffinityManager().getDeviceForArray(steady)); - } - } catch (Exception e) { - log.info(e.getMessage()); - } - } - } - - @Test - public void testEventsRelease() { - FlowController controller = AtomicAllocator.getInstance().getFlowController(); - long currEventsNumber = controller.getEventsProvider().getEventsNumber(); - - INDArray x = Nd4j.rand(1,10); - controller.prepareAction(x); - assertEquals(currEventsNumber+1, controller.getEventsProvider().getEventsNumber()); - - INDArray arg1 = Nd4j.rand(1,100); - INDArray arg2 = Nd4j.rand(1,200); - INDArray arg3 = Nd4j.rand(1,300); - controller.prepareAction(x, arg1, arg2, arg3); - assertEquals(currEventsNumber+5, controller.getEventsProvider().getEventsNumber()); - } - - @Test - public void testDataBuffers() { - INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - assertEquals(50, x.data().capacity()); - x.data().destroy(); - assertNull(x.data()); - assertEquals(64, pointX.getShape().getNumberOfBytes()); - System.out.println(pointX.getHostPointer()); - System.out.println(pointX.getDevicePointer()); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java index 2f5d53a40..09c1ebb04 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java @@ -2,150 +2,201 @@ package org.nd4j.linalg.jcublas.buffer; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.junit.Before; import org.junit.Test; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.workspace.CudaWorkspace; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; + import static org.junit.Assert.*; @Slf4j public class BaseCudaDataBufferTest { - @Test - public void testShapeCache_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - - assertEquals(DataType.FLOAT, x.dataType()); - assertArrayEquals(new long[]{3, 5}, x.shape()); - assertArrayEquals(new long[]{5, 1}, x.stride()); - assertEquals(1, x.elementWiseStride()); - assertEquals('c', x.ordering()); - - val pair = Nd4j.getShapeInfoProvider().createShapeInformation(x.shape(), x.stride(), x.elementWiseStride(), x.ordering(), x.dataType(), x.isEmpty()); - val db = pair.getFirst(); - val jvm = pair.getSecond(); - - log.info("array shapeInfo: {}", x.shapeInfoJava()); - log.info("direct shapeInfo: {}", jvm); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.shapeInfoDataBuffer()); - val pointM = AtomicAllocator.getInstance().getAllocationPoint(db); - - assertNotNull(pointX); - assertNotNull(pointM); - - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertNotNull(pointM.getHostPointer()); - assertNotNull(pointM.getDevicePointer()); - - - log.info("X hPtr: {}; dPtr: {}", pointX.getHostPointer().address(), pointX.getDevicePointer().address()); - log.info("M hPtr: {}; dPtr: {}", pointM.getHostPointer().address(), pointM.getDevicePointer().address()); - - assertEquals(pointM.getHostPointer().address(), pointX.getHostPointer().address()); - assertEquals(pointM.getDevicePointer().address(), pointX.getDevicePointer().address()); - - assertArrayEquals(x.shapeInfoJava(), jvm); + @Before + public void setUp() { + // } @Test - public void testTadCache_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - val row = x.getRow(1); - val tad = x.tensorAlongDimension(1, 1); + public void testBasicAllocation_1() { + val array = Nd4j.create(DataType.FLOAT, 5); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(row.shapeInfoDataBuffer()); - val pointM = AtomicAllocator.getInstance().getAllocationPoint(tad.shapeInfoDataBuffer()); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - assertNotNull(pointX); - assertNotNull(pointM); + // shape part + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoJava()); + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoDataBuffer().asLong()); - assertNotNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - assertNotNull(pointM.getHostPointer()); - assertNotNull(pointM.getDevicePointer()); - - - log.info("X hPtr: {}; dPtr: {}", pointX.getHostPointer().address(), pointX.getDevicePointer().address()); - log.info("M hPtr: {}; dPtr: {}", pointM.getHostPointer().address(), pointM.getDevicePointer().address()); - - assertEquals(pointM.getHostPointer().address(), pointX.getHostPointer().address()); - assertEquals(pointM.getDevicePointer().address(), pointX.getDevicePointer().address()); - - assertArrayEquals(row.shapeInfoJava(), tad.shapeInfoJava()); - } - - - @Test - public void testHostAllocation_1() { - val x = Nd4j.create(DataType.FLOAT, 3, 5); - - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); - - assertNotNull(pointX); - - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - - x.getDouble(0); - - assertNotNull(pointX.getHostPointer()); + // arrat as full of zeros at this point + assertArrayEquals(new float[] {0.f, 0.f, 0.f, 0.f, 0.f}, array.data().asFloat(), 1e-5f); } @Test - public void testHostAllocation_2() { - val x = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5}); + public void testBasicAllocation_2() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - assertNotNull(pointX); + // shape part + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoJava()); + assertArrayEquals(new long[]{1, 5, 1, 8192, 1, 99}, array.shapeInfoDataBuffer().asLong()); - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); - - val sum = x.sumNumber().doubleValue(); - - assertNull(pointX.getHostPointer()); - - assertEquals(15, sum, 1e-5); - - x.getDouble(0); - - assertNotNull(pointX.getHostPointer()); + // arrat as full of values at this point + assertArrayEquals(new float[] {1.f, 2.f, 3.f, 4.f, 5.f}, array.data().asFloat(), 1e-5f); } @Test - public void testHostAllocation_3() { - val wsConf = WorkspaceConfiguration.builder() - .initialSize(10 * 1024 * 1024) - .build(); + public void testBasicView_1() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f).reshape(3, 2); - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "someworkspaceid")) { - val x = Nd4j.create(DataType.DOUBLE, 3, 5); + // basic validation + assertNotNull(array); + assertNotNull(array.data()); + assertNotNull(((BaseCudaDataBuffer) array.data()).getOpaqueDataBuffer()); - val pointX = AtomicAllocator.getInstance().getAllocationPoint(x.data()); + // checking TAD equality + val row = array.getRow(1); + assertArrayEquals(new float[]{3.0f, 4.0f}, row.data().dup().asFloat(), 1e-5f); + } - assertNotNull(pointX); + @Test + public void testScalar_1() { + val scalar = Nd4j.scalar(119.f); - assertNull(pointX.getHostPointer()); - assertNotNull(pointX.getDevicePointer()); + // basic validation + assertNotNull(scalar); + assertNotNull(scalar.data()); + assertNotNull(((BaseCudaDataBuffer) scalar.data()).getOpaqueDataBuffer()); - assertEquals(0, ((CudaWorkspace) ws).getHostOffset()); + // shape part + assertArrayEquals(new long[]{0, 8192, 1, 99}, scalar.shapeInfoJava()); + assertArrayEquals(new long[]{0, 8192, 1, 99}, scalar.shapeInfoDataBuffer().asLong()); - x.getDouble(0); + // pointers part + val devPtr = AtomicAllocator.getInstance().getPointer(scalar.data()); + val hostPtr = AtomicAllocator.getInstance().getHostPointer(scalar.data()); + // dev pointer supposed to exist, and host pointer is not + assertNotNull(devPtr); + assertNull(hostPtr); - assertEquals(ws.getPrimaryOffset(), ((CudaWorkspace) ws).getHostOffset()); - assertNotEquals(0, ws.getPrimaryOffset()); + assertEquals(119.f, scalar.getFloat(0), 1e-5f); + } - assertNotNull(pointX.getHostPointer()); + @Test + public void testSerDe_1() throws Exception { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + val baos = new ByteArrayOutputStream(); + + Nd4j.write(baos, array); + INDArray restored = Nd4j.read(new ByteArrayInputStream(baos.toByteArray())); + + // basic validation + assertNotNull(restored); + assertNotNull(restored.data()); + assertNotNull(((BaseCudaDataBuffer) restored.data()).getOpaqueDataBuffer()); + + // shape part + assertArrayEquals(new long[]{1, 6, 1, 8192, 1, 99}, restored.shapeInfoJava()); + assertArrayEquals(new long[]{1, 6, 1, 8192, 1, 99}, restored.shapeInfoDataBuffer().asLong()); + + // data equality + assertArrayEquals(array.data().asFloat(), restored.data().asFloat(), 1e-5f); + } + + @Test + public void testBasicOpInvocation_1() { + val array1 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + val array2 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + // shape pointers must be equal here + val devPtr1 = AtomicAllocator.getInstance().getPointer(array1.shapeInfoDataBuffer()); + val devPtr2 = AtomicAllocator.getInstance().getPointer(array2.shapeInfoDataBuffer()); + + val hostPtr1 = AtomicAllocator.getInstance().getHostPointer(array1.shapeInfoDataBuffer()); + val hostPtr2 = AtomicAllocator.getInstance().getHostPointer(array2.shapeInfoDataBuffer()); + + // pointers must be equal on host and device, since we have shape cache + assertEquals(devPtr1.address(), devPtr2.address()); + assertEquals(hostPtr1.address(), hostPtr2.address()); + + assertEquals(array1, array2); + } + + @Test + public void testBasicOpInvocation_2() { + val array1 = Nd4j.createFromArray(1.f, 200.f, 3.f, 4.f, 5.f, 6.f); + val array2 = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + assertNotEquals(array1, array2); + } + + @Test + public void testBasicOpInvocation_3() { + val array = Nd4j.create(DataType.FLOAT, 6); + val exp = Nd4j.createFromArray(1.f, 1.f, 1.f, 1.f, 1.f, 1.f); + + array.addi(1.0f); + + assertEquals(exp, array); + } + + @Test + public void testCustomOpInvocation_1() { + val array = Nd4j.createFromArray(1.f, 2.f, 3.f, 4.f, 5.f, 6.f); + + Nd4j.exec(new PrintVariable(array, true)); + Nd4j.exec(new PrintVariable(array)); + } + + @Test + public void testMultiDeviceMigration_1() throws Exception { + if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) + return; + + // creating all arrays within main thread context + val list = new ArrayList(); + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) + list.add(Nd4j.create(DataType.FLOAT, 3, 5)); + + val cnt = new AtomicInteger(0); + + // now we're creating threads + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; + val t = new Thread(new Runnable() { + @Override + public void run() { + // issuing one operation, just to see how migration works + list.get(f).addi(1.0f); + + // synchronizing immediately + Nd4j.getExecutioner().commit(); + cnt.incrementAndGet(); + } + }); + + t.start(); + t.join(); } + + // there shoul dbe no exceptions during execution + assertEquals(Nd4j.getAffinityManager().getNumberOfDevices(), cnt.get()); } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml index 64be62442..48cdc3e03 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/pom.xml @@ -61,6 +61,7 @@ ${mkl.version}-${javacpp-presets.version} ${dependency.platform2} + org.nd4j nd4j-native-api diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 5f48b3c64..38cb6610e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -28,6 +28,9 @@ import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.compression.CompressionUtils; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.primitives.Pair; import org.bytedeco.javacpp.*; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -558,11 +561,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { } nativeOps.tear(null, - tensor.data().pointer(), (LongPointer) tensor.shapeInfoDataBuffer().pointer(), - null, null, + ((BaseCpuDataBuffer) tensor.data()).getOpaqueDataBuffer(), (LongPointer) tensor.shapeInfoDataBuffer().pointer(), null, targets, (LongPointer) result[0].shapeInfoDataBuffer().pointer(), - (LongPointer) tadBuffers.getFirst().pointer(), - new LongPointerWrapper(tadBuffers.getSecond().pointer()) + (LongPointer) tadBuffers.getFirst().pointer(), new LongPointerWrapper(tadBuffers.getSecond().pointer()) ); if (nativeOps.lastErrorCode() != 0) @@ -701,10 +702,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { nativeOps.pullRows(dummy, - source.data().addressPointer(), (LongPointer) source.shapeInfoDataBuffer().addressPointer(), - null, null, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null, + ((BaseCpuDataBuffer) source.data()).getOpaqueDataBuffer(), (LongPointer) source.shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) ret.data()).getOpaqueDataBuffer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null, indexes.length, pIndex, (LongPointer) hostTadShapeInfo, new LongPointerWrapper(hostTadOffsets), diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java index 36599c859..c48178055 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java @@ -18,20 +18,13 @@ package org.nd4j.linalg.cpu.nativecpu; import lombok.NonNull; import lombok.val; -import org.bytedeco.javacpp.IntPointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.IntBuffer; -import org.nd4j.linalg.api.buffer.LongBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.cache.ConstantHandler; import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.cache.TadDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.NativeOps; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java index a6cd47fb0..aae332d78 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/NDArray.java @@ -17,8 +17,12 @@ package org.nd4j.linalg.cpu.nativecpu; +import com.google.flatbuffers.FlatBufferBuilder; import lombok.val; +import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; +import org.nd4j.base.Preconditions; +import org.nd4j.graph.FlatArray; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.BaseNDArray; @@ -27,10 +31,17 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.JvmShapeInfo; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.cpu.nativecpu.buffer.DoubleBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.FloatBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.MemcpyDirection; import org.nd4j.linalg.workspace.WorkspaceUtils; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.List; @@ -488,4 +499,36 @@ public class NDArray extends BaseNDArray { public LongShapeDescriptor shapeDescriptor() { return LongShapeDescriptor.fromShape(shape(), stride(), elementWiseStride(), ordering(), dataType(), isEmpty()); } + + protected int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer) { + Preconditions.checkArgument(buffer.dataType() == DataType.UTF8, "This method can be called on UTF8 buffers only"); + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(bos); + + val numWords = this.length(); + val ub = (Utf8Buffer) buffer; + // writing length first + val t = length(); + val ptr = (BytePointer) ub.pointer(); + + // now write all strings as bytes + for (int i = 0; i < ub.length(); i++) { + dos.writeByte(ptr.get(i)); + } + + val bytes = bos.toByteArray(); + return FlatArray.createBufferVector(builder, bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getString(long index) { + if (!isS()) + throw new UnsupportedOperationException("This method is usable only on String dataType, but got [" + this.dataType() + "]"); + + return ((Utf8Buffer) data).getString(index); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java index f2c8a9202..819b30339 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BFloat16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BFloat16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class BFloat16Buffer extends BaseDataBuffer { +public class BFloat16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -50,6 +53,10 @@ public class BFloat16Buffer extends BaseDataBuffer { } + public BFloat16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public BFloat16Buffer(long length, boolean initialize) { super(length, initialize); } @@ -111,18 +118,6 @@ public class BFloat16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public BFloat16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public BFloat16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public BFloat16Buffer(byte[] data, int length) { - super(data, length); - } - public BFloat16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java new file mode 100644 index 000000000..ec4a0e51a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -0,0 +1,939 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.cpu.nativecpu.buffer; + +import lombok.val; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.indexer.*; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.buffer.util.AllocUtil; +import org.nd4j.linalg.api.memory.Deallocatable; +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.pointers.PagedPointer; +import org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.memory.deallocation.DeallocatorService; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +import java.nio.ByteBuffer; + +import static org.nd4j.linalg.api.buffer.DataType.INT16; +import static org.nd4j.linalg.api.buffer.DataType.INT8; + +/** + * Base implementation for DataBuffer for CPU-like backend + * + * @author raver119@gmail.com + */ +public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallocatable { + + protected transient OpaqueDataBuffer ptrDataBuffer; + + private final long instanceId = Nd4j.getDeallocatorService().nextValue(); + + protected BaseCpuDataBuffer() { + + } + + + @Override + public String getUniqueId() { + return "BCDB_" + instanceId; + } + + @Override + public Deallocator deallocator() { + return new CpuDeallocator(this); + } + + public OpaqueDataBuffer getOpaqueDataBuffer() { + return ptrDataBuffer; + } + + @Override + public int targetDevice() { + // TODO: once we add NUMA support this might change. Or might not. + return 0; + } + + + /** + * + * @param length + * @param elementSize + */ + public BaseCpuDataBuffer(long length, int elementSize) { + if (length < 1) + throw new IllegalArgumentException("Length must be >= 1"); + initTypeAndSize(); + allocationMode = AllocUtil.getAllocationModeFromContext(); + this.length = length; + this.underlyingLength = length; + this.elementSize = (byte) elementSize; + + if (dataType() != DataType.UTF8) + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer(); + + indexer = DoubleIndexer.create((DoublePointer) pointer); + } else if (dataType() == DataType.FLOAT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer(); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (dataType() == DataType.INT32) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (dataType() == DataType.LONG) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (dataType() == DataType.SHORT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.BYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UBYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UTF8) { + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, INT8, false); + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } + + Nd4j.getDeallocatorService().pickObject(this); + } + + /** + * + * @param length + * @param elementSize + */ + public BaseCpuDataBuffer(int length, int elementSize, long offset) { + this(length, elementSize); + this.offset = offset; + this.originalOffset = offset; + this.length = length - offset; + this.underlyingLength = length; + } + + + protected BaseCpuDataBuffer(DataBuffer underlyingBuffer, long length, long offset) { + super(underlyingBuffer, length, offset); + + // for vew we need "externally managed" pointer and deallocator registration + ptrDataBuffer = ((BaseCpuDataBuffer) underlyingBuffer).ptrDataBuffer.createView(length * underlyingBuffer.getElementSize(), offset * underlyingBuffer.getElementSize()); + Nd4j.getDeallocatorService().pickObject(this); + + + // update pointer now + actualizePointerAndIndexer(); + } + + protected BaseCpuDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) { + this(length, Nd4j.sizeOfDataType(dtype)); + + Pointer temp = null; + + switch (dataType()){ + case DOUBLE: + temp = new DoublePointer(buffer.asDoubleBuffer()); + break; + case FLOAT: + temp = new FloatPointer(buffer.asFloatBuffer()); + break; + case HALF: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case LONG: + temp = new LongPointer(buffer.asLongBuffer()); + break; + case INT: + temp = new IntPointer(buffer.asIntBuffer()); + break; + case SHORT: + temp = new ShortPointer(buffer.asShortBuffer()); + break; + case UBYTE: //Fall through + case BYTE: + temp = new BytePointer(buffer); + break; + case BOOL: + temp = new BooleanPointer(length()); + break; + case UTF8: + temp = new BytePointer(length()); + break; + case BFLOAT16: + temp = new ShortPointer(length()); + break; + case UINT16: + temp = new ShortPointer(length()); + break; + case UINT32: + temp = new IntPointer(length()); + break; + case UINT64: + temp = new LongPointer(length()); + break; + } + + val ptr = ptrDataBuffer.primaryBuffer(); + + if (offset > 0) + temp = new PagedPointer(temp.address() + offset * getElementSize()); + + Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype)); + } + + @Override + public void pointerIndexerByCurrentType(DataType currentType) { + + type = currentType; + + if (ptrDataBuffer == null) { + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), type, false); + Nd4j.getDeallocatorService().pickObject(this); + } + + actualizePointerAndIndexer(); + } + + /** + * Instantiate a buffer with the given length + * + * @param length the length of the buffer + */ + protected BaseCpuDataBuffer(long length) { + this(length, true); + } + + protected BaseCpuDataBuffer(long length, boolean initialize) { + if (length < 0) + throw new IllegalArgumentException("Length must be >= 0"); + initTypeAndSize(); + this.length = length; + this.underlyingLength = length; + allocationMode = AllocUtil.getAllocationModeFromContext(); + if (length < 0) + throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); + + if (dataType() != DataType.UTF8) + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asDoublePointer(); + + indexer = DoubleIndexer.create((DoublePointer) pointer); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.FLOAT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asFloatPointer(); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + + } else if (dataType() == DataType.HALF) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BFLOAT16) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.INT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + setIndexer(IntIndexer.create((IntPointer) pointer)); + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.LONG) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + setIndexer(LongIndexer.create((LongPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.SHORT) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UBYTE) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); + + setIndexer(UByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT16) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); + + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT32) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); + + // FIXME: we need unsigned indexer here + setIndexer(IntIndexer.create((IntPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UINT64) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); + + // FIXME: we need unsigned indexer here + setIndexer(LongIndexer.create((LongPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.BOOL) { + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer(); + + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } else if (dataType() == DataType.UTF8) { + // we are allocating buffer as INT8 intentionally + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length(), INT8, false); + pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length()).asBytePointer(); + + setIndexer(ByteIndexer.create((BytePointer) pointer)); + + if (initialize) + fillPointerWithZero(); + } + + Nd4j.getDeallocatorService().pickObject(this); + } + + public void actualizePointerAndIndexer() { + val cptr = ptrDataBuffer.primaryBuffer(); + + // skip update if pointers are equal + if (cptr != null && pointer != null && cptr.address() == pointer.address()) + return; + + val t = dataType(); + if (t == DataType.BOOL) { + pointer = new PagedPointer(cptr, length).asBoolPointer(); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (t == DataType.UBYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.BYTE) { + pointer = new PagedPointer(cptr, length).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (t == DataType.UINT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.SHORT) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.UINT32) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.INT) { + pointer = new PagedPointer(cptr, length).asIntPointer(); + setIndexer(IntIndexer.create((IntPointer) pointer)); + } else if (t == DataType.UINT64) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.LONG) { + pointer = new PagedPointer(cptr, length).asLongPointer(); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (t == DataType.BFLOAT16) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (t == DataType.HALF) { + pointer = new PagedPointer(cptr, length).asShortPointer(); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + } else if (t == DataType.FLOAT) { + pointer = new PagedPointer(cptr, length).asFloatPointer(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + } else if (t == DataType.DOUBLE) { + pointer = new PagedPointer(cptr, length).asDoublePointer(); + setIndexer(DoubleIndexer.create((DoublePointer) pointer)); + } else if (t == DataType.UTF8) { + pointer = new PagedPointer(cptr, length()).asBytePointer(); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else + throw new IllegalArgumentException("Unknown datatype: " + dataType()); + } + + @Override + public Pointer addressPointer() { + // we're fetching actual pointer right from C++ + val tempPtr = new PagedPointer(ptrDataBuffer.primaryBuffer()); + + switch (this.type) { + case DOUBLE: return tempPtr.asDoublePointer(); + case FLOAT: return tempPtr.asFloatPointer(); + case UINT16: + case SHORT: + case BFLOAT16: + case HALF: return tempPtr.asShortPointer(); + case UINT32: + case INT: return tempPtr.asIntPointer(); + case UBYTE: + case BYTE: return tempPtr.asBytePointer(); + case UINT64: + case LONG: return tempPtr.asLongPointer(); + case BOOL: return tempPtr.asBoolPointer(); + default: return tempPtr.asBytePointer(); + } + } + + protected BaseCpuDataBuffer(long length, boolean initialize, MemoryWorkspace workspace) { + if (length < 1) + throw new IllegalArgumentException("Length must be >= 1"); + initTypeAndSize(); + this.length = length; + this.underlyingLength = length; + allocationMode = AllocUtil.getAllocationModeFromContext(); + + + + if (length < 0) + throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); + + // creating empty native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + + if (dataType() == DataType.DOUBLE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asDoublePointer(); //new DoublePointer(length()); + indexer = DoubleIndexer.create((DoublePointer) pointer); + + } else if (dataType() == DataType.FLOAT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asFloatPointer(); //new FloatPointer(length()); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + + } else if (dataType() == DataType.HALF) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); + setIndexer(HalfIndexer.create((ShortPointer) pointer)); + + } else if (dataType() == DataType.BFLOAT16) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); + setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.INT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + } else if (dataType() == DataType.UINT32) { + attached = true; + parentWorkspace = workspace; + + // FIXME: need unsigned indexer here + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + } else if (dataType() == DataType.UINT64) { + attached = true; + parentWorkspace = workspace; + + // FIXME: need unsigned indexer here + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new IntPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + + } else if (dataType() == DataType.LONG) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } else if (dataType() == DataType.BYTE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); + setIndexer(ByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UBYTE) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); + setIndexer(UByteIndexer.create((BytePointer) pointer)); + } else if (dataType() == DataType.UINT16) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new IntPointer(length()); + setIndexer(UShortIndexer.create((ShortPointer) pointer)); + + } else if (dataType() == DataType.SHORT) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new LongPointer(length()); + setIndexer(ShortIndexer.create((ShortPointer) pointer)); + } else if (dataType() == DataType.BOOL) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBoolPointer(); //new LongPointer(length()); + setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); + } else if (dataType() == DataType.UTF8) { + attached = true; + parentWorkspace = workspace; + + pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); + setIndexer(LongIndexer.create((LongPointer) pointer)); + } + + // storing pointer into native DataBuffer + ptrDataBuffer.setPrimaryBuffer(pointer, length); + + // adding deallocator reference + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + } + + public BaseCpuDataBuffer(Pointer pointer, Indexer indexer, long length) { + super(pointer, indexer, length); + + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); + ptrDataBuffer.setPrimaryBuffer(this.pointer, length); + Nd4j.getDeallocatorService().pickObject(this);; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(float[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + + } + + public BaseCpuDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) { + this(data, copy, workspace); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(float[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new FloatPointer(data); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.FLOAT, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + //wrappedBuffer = pointer.asByteBuffer(); + + length = data.length; + underlyingLength = data.length; + } + + public BaseCpuDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + setIndexer(FloatIndexer.create((FloatPointer) pointer)); + //wrappedBuffer = pointer.asByteBuffer(); + } + + public BaseCpuDataBuffer(double[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + + public BaseCpuDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = IntIndexer.create((IntPointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + public BaseCpuDataBuffer(long[] data, boolean copy, MemoryWorkspace workspace) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + length = data.length; + underlyingLength = data.length; + attached = true; + parentWorkspace = workspace; + + initTypeAndSize(); + + //log.info("Allocating FloatPointer from array of {} elements", data.length); + + pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data); + + this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); + Nd4j.getDeallocatorService().pickObject(this); + + workspaceGenerationId = workspace.getGenerationId(); + indexer = LongIndexer.create((LongPointer) pointer); + //wrappedBuffer = pointer.asByteBuffer(); + } + + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(double[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.underlyingLength = data.length; + this.length = underlyingLength - offset; + } + + public BaseCpuDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) { + this(data, copy, workspace); + this.offset = offset; + this.originalOffset = offset; + this.underlyingLength = data.length; + this.length = underlyingLength - offset; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(double[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new DoublePointer(data); + indexer = DoubleIndexer.create((DoublePointer) pointer); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.DOUBLE, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(int[] data, boolean copy, long offset) { + this(data, copy); + this.offset = offset; + this.originalOffset = offset; + this.length = data.length - offset; + this.underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(int[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new IntPointer(data); + setIndexer(IntIndexer.create((IntPointer) pointer)); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT32, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + /** + * + * @param data + * @param copy + */ + public BaseCpuDataBuffer(long[] data, boolean copy) { + allocationMode = AllocUtil.getAllocationModeFromContext(); + initTypeAndSize(); + + pointer = new LongPointer(data); + setIndexer(LongIndexer.create((LongPointer) pointer)); + + // creating & registering native DataBuffer + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(data.length, DataType.INT64, false); + ptrDataBuffer.setPrimaryBuffer(pointer, data.length); + Nd4j.getDeallocatorService().pickObject(this); + + length = data.length; + underlyingLength = data.length; + } + + + /** + * + * @param data + */ + public BaseCpuDataBuffer(double[] data) { + this(data, true); + } + + /** + * + * @param data + */ + public BaseCpuDataBuffer(int[] data) { + this(data, true); + } + + /** + * + * @param data + */ + public BaseCpuDataBuffer(float[] data) { + this(data, true); + } + + public BaseCpuDataBuffer(float[] data, MemoryWorkspace workspace) { + this(data, true, workspace); + } + + /** + * Reallocate the native memory of the buffer + * @param length the new length of the buffer + * @return this databuffer + * */ + @Override + public DataBuffer reallocate(long length) { + val oldPointer = ptrDataBuffer.primaryBuffer(); + + if (isAttached()) { + val capacity = length * getElementSize(); + val nPtr = getParentWorkspace().alloc(capacity, dataType(), false); + this.ptrDataBuffer.setPrimaryBuffer(nPtr, length); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + + Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize()); + workspaceGenerationId = getParentWorkspace().getGenerationId(); + } else { + this.ptrDataBuffer.expand(length); + val nPtr = new PagedPointer(this.ptrDataBuffer.primaryBuffer(), length); + + switch (dataType()) { + case BOOL: + pointer = nPtr.asBoolPointer(); + indexer = BooleanIndexer.create((BooleanPointer) pointer); + break; + case UTF8: + case BYTE: + case UBYTE: + pointer = nPtr.asBytePointer(); + indexer = ByteIndexer.create((BytePointer) pointer); + break; + case UINT16: + case SHORT: + pointer = nPtr.asShortPointer(); + indexer = ShortIndexer.create((ShortPointer) pointer); + break; + case UINT32: + case INT: + pointer = nPtr.asIntPointer(); + indexer = IntIndexer.create((IntPointer) pointer); + break; + case DOUBLE: + pointer = nPtr.asDoublePointer(); + indexer = DoubleIndexer.create((DoublePointer) pointer); + break; + case FLOAT: + pointer = nPtr.asFloatPointer(); + indexer = FloatIndexer.create((FloatPointer) pointer); + break; + case HALF: + pointer = nPtr.asShortPointer(); + indexer = HalfIndexer.create((ShortPointer) pointer); + break; + case BFLOAT16: + pointer = nPtr.asShortPointer(); + indexer = Bfloat16Indexer.create((ShortPointer) pointer); + break; + case UINT64: + case LONG: + pointer = nPtr.asLongPointer(); + indexer = LongIndexer.create((LongPointer) pointer); + break; + } + } + + this.underlyingLength = length; + this.length = length; + return this; + } + +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java index 51ea5ca25..6f1bb5f99 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BoolBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BoolBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class BoolBuffer extends BaseDataBuffer { +public class BoolBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class BoolBuffer extends BaseDataBuffer { */ public BoolBuffer(long length) { super(length); + } + public BoolBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public BoolBuffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class BoolBuffer extends BaseDataBuffer { super(data, copy, offset); } - public BoolBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public BoolBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public BoolBuffer(byte[] data, int length) { - super(data, length); - } - public BoolBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java new file mode 100644 index 000000000..3b8a46fa6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.cpu.nativecpu.buffer; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +/** + * This class is responsible for OpaqueDataBuffer deletion on native side, once it's not used anymore in Java + * + * @author raver119@gmail.com + */ +@Slf4j +public class CpuDeallocator implements Deallocator { + private OpaqueDataBuffer opaqueDataBuffer; + + public CpuDeallocator(BaseCpuDataBuffer buffer) { + opaqueDataBuffer = buffer.getOpaqueDataBuffer(); + } + + @Override + public void deallocate() { + if (opaqueDataBuffer == null) + throw new RuntimeException("opaqueDataBuffer is null"); + + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteDataBuffer(opaqueDataBuffer); + } +} diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java similarity index 93% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java index 65d605e00..54b02e309 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DefaultDataBufferFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DefaultDataBufferFactory.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer.factory; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.NonNull; import org.bytedeco.javacpp.DoublePointer; @@ -26,6 +26,7 @@ import org.bytedeco.javacpp.indexer.FloatIndexer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.IntIndexer; import org.nd4j.linalg.api.buffer.*; +import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.util.ArrayUtil; @@ -93,20 +94,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return null; } - @Override - public DataBuffer createInt(long offset, ByteBuffer buffer, int length) { - return new IntBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createFloat(long offset, ByteBuffer buffer, int length) { - return new FloatBuffer(buffer, length, offset); - } - - @Override - public DataBuffer createDouble(long offset, ByteBuffer buffer, int length) { - return new DoubleBuffer(buffer, length, offset); - } @Override public DataBuffer createDouble(long offset, int length) { @@ -236,25 +223,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new IntBuffer(ArrayUtil.toInts(data), copy, offset); } - @Override - public DataBuffer createInt(ByteBuffer buffer, int length) { - return new IntBuffer(buffer, length); - } - - @Override - public DataBuffer createLong(ByteBuffer buffer, int length) { - return new LongBuffer(buffer, length); - } - - @Override - public DataBuffer createFloat(ByteBuffer buffer, int length) { - return new FloatBuffer(buffer, length); - } - - @Override - public DataBuffer createDouble(ByteBuffer buffer, int length) { - return new DoubleBuffer(buffer, length); - } @Override public DataBuffer createDouble(long length) { @@ -281,6 +249,42 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new FloatBuffer(length, initialize, workspace); } + @Override + public DataBuffer create(ByteBuffer underlyingBuffer, DataType dataType, long length, long offset) { + switch (dataType) { + case DOUBLE: + return new DoubleBuffer(underlyingBuffer, dataType, length, offset); + case FLOAT: + return new FloatBuffer(underlyingBuffer, dataType, length, offset); + case HALF: + return new HalfBuffer(underlyingBuffer, dataType, length, offset); + case BFLOAT16: + return new BFloat16Buffer(underlyingBuffer, dataType, length, offset); + case LONG: + return new LongBuffer(underlyingBuffer, dataType, length, offset); + case INT: + return new IntBuffer(underlyingBuffer, dataType, length, offset); + case SHORT: + return new Int16Buffer(underlyingBuffer, dataType, length, offset); + case UBYTE: + return new UInt8Buffer(underlyingBuffer, dataType, length, offset); + case UINT16: + return new UInt16Buffer(underlyingBuffer, dataType, length, offset); + case UINT32: + return new UInt32Buffer(underlyingBuffer, dataType, length, offset); + case UINT64: + return new UInt64Buffer(underlyingBuffer, dataType, length, offset); + case BYTE: + return new Int8Buffer(underlyingBuffer, dataType, length, offset); + case BOOL: + return new BoolBuffer(underlyingBuffer, dataType, length, offset); + case UTF8: + return new Utf8Buffer(underlyingBuffer, dataType, length, offset); + default: + throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); + } + } + @Override public DataBuffer create(@NonNull DataType dataType, long length, boolean initialize) { switch (dataType) { @@ -310,11 +314,11 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return new Int8Buffer(length, initialize); case BOOL: return new BoolBuffer(length, initialize); + case UTF8: + return new Utf8Buffer(length, true); default: throw new IllegalStateException("Unknown datatype used: [" + dataType + "]"); - } - } @Override @@ -540,16 +544,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { return createDouble(data, true); } - @Override - public DataBuffer createDouble(byte[] data, int length) { - return new DoubleBuffer(ByteBuffer.wrap(data), length); - } - - @Override - public DataBuffer createFloat(byte[] data, int length) { - return new FloatBuffer(ByteBuffer.wrap(data), length); - } - @Override public DataBuffer createFloat(double[] data) { return createFloat(data, true); @@ -958,18 +952,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); } - /** - * Creates a half-precision data buffer - * - * @param offset - * @param data the data to create the buffer from - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(long offset, byte[] data, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } /** * Creates a half-precision data buffer @@ -983,30 +965,6 @@ public class DefaultDataBufferFactory implements DataBufferFactory { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); } - /** - * Creates a half-precision data buffer - * - * @param buffer - * @param length - * @return the new buffer - */ - @Override - public DataBuffer createHalf(ByteBuffer buffer, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - @Override - public DataBuffer createHalf(byte[] data, int length) { - throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); - } - @Override public DataBuffer createHalf(long length, boolean initialize, MemoryWorkspace workspace) { throw new UnsupportedOperationException("FP16 isn't supported for CPU yet"); @@ -1046,4 +1004,8 @@ public class DefaultDataBufferFactory implements DataBufferFactory { public Class doubleBufferClass() { return DoubleBuffer.class; } + + public DataBuffer createUtf8Buffer(byte[] data, long product) { + return new Utf8Buffer(data, product); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java index 8bd4bd6a1..25d1997d1 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DoubleBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/DoubleBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class DoubleBuffer extends BaseDataBuffer { +public class DoubleBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer * @@ -40,6 +43,10 @@ public class DoubleBuffer extends BaseDataBuffer { super(pointer, indexer, length); } + public DoubleBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public DoubleBuffer(long length) { super(length); } @@ -100,18 +107,6 @@ public class DoubleBuffer extends BaseDataBuffer { super(data, copy, offset); } - public DoubleBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public DoubleBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public DoubleBuffer(byte[] data, int length) { - super(data, length); - } - public DoubleBuffer(double[] doubles, boolean copy) { super(doubles, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java index 5b598c920..1a6d2846a 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/FloatBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/FloatBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class FloatBuffer extends BaseDataBuffer { +public class FloatBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -50,6 +53,10 @@ public class FloatBuffer extends BaseDataBuffer { } + public FloatBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public FloatBuffer(long length, boolean initialize) { super(length, initialize); } @@ -111,18 +118,6 @@ public class FloatBuffer extends BaseDataBuffer { super(data, copy, offset); } - public FloatBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public FloatBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public FloatBuffer(byte[] data, int length) { - super(data, length); - } - public FloatBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java index d2cb2cfcc..1fdb338b2 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/HalfBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/HalfBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class HalfBuffer extends BaseDataBuffer { +public class HalfBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class HalfBuffer extends BaseDataBuffer { */ public HalfBuffer(long length) { super(length); + } + public HalfBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public HalfBuffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class HalfBuffer extends BaseDataBuffer { super(data, copy, offset); } - public HalfBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public HalfBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public HalfBuffer(byte[] data, int length) { - super(data, length); - } - public HalfBuffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java index f5cd2245f..7bf6eb969 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class Int16Buffer extends BaseDataBuffer { +public class Int16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class Int16Buffer extends BaseDataBuffer { */ public Int16Buffer(long length) { super(length); + } + public Int16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Int16Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class Int16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public Int16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public Int16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public Int16Buffer(byte[] data, int length) { - super(data, length); - } - public Int16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java index aeec19961..7f14d9ae8 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Int8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Int8Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class Int8Buffer extends BaseDataBuffer { +public class Int8Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class Int8Buffer extends BaseDataBuffer { */ public Int8Buffer(long length) { super(length); + } + public Int8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Int8Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class Int8Buffer extends BaseDataBuffer { super(data, copy, offset); } - public Int8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public Int8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public Int8Buffer(byte[] data, int length) { - super(data, length); - } - public Int8Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java similarity index 90% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java index 20ec86bfd..de4282993 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/IntBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/IntBuffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class IntBuffer extends BaseDataBuffer { +public class IntBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -53,18 +56,14 @@ public class IntBuffer extends BaseDataBuffer { super(length, initialize, workspace); } + public IntBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public IntBuffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public IntBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public IntBuffer(byte[] data, int length) { - super(data, length); - } - public IntBuffer(double[] data, boolean copy) { super(data, copy); } @@ -97,10 +96,6 @@ public class IntBuffer extends BaseDataBuffer { super(underlyingBuffer, length, offset); } - public IntBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - @Override protected DataBuffer create(long length) { return new IntBuffer(length); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java similarity index 83% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java index 42981e135..80a7f9560 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/LongBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java @@ -14,17 +14,22 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.NonNull; -import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.pointers.PagedPointer; +import org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOpsHolder; import java.nio.ByteBuffer; @@ -33,7 +38,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class LongBuffer extends BaseDataBuffer { +public class LongBuffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -58,17 +63,14 @@ public class LongBuffer extends BaseDataBuffer { super(length, initialize, workspace); } + public LongBuffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); + } + public LongBuffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public LongBuffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public LongBuffer(byte[] data, int length) { - super(data, length); - } public LongBuffer(double[] data, boolean copy) { super(data, copy); @@ -110,10 +112,6 @@ public class LongBuffer extends BaseDataBuffer { super(underlyingBuffer, length, offset); } - public LongBuffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - public LongBuffer(@NonNull Pointer hostPointer, long numberOfElements) { this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.offset = 0; @@ -124,6 +122,12 @@ public class LongBuffer extends BaseDataBuffer { this.pointer = new PagedPointer(hostPointer, numberOfElements).asLongPointer(); indexer = LongIndexer.create((LongPointer) this.pointer); + + // we still want this buffer to have native representation + ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); + NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements); + + Nd4j.getDeallocatorService().pickObject(this); } @Override diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java index 9d0e8d02c..d4bc705ec 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt16Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt16Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt16Buffer extends BaseDataBuffer { +public class UInt16Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt16Buffer extends BaseDataBuffer { */ public UInt16Buffer(long length) { super(length); + } + public UInt16Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt16Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt16Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt16Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt16Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt16Buffer(byte[] data, int length) { - super(data, length); - } - public UInt16Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java index 7df2621c7..b18fafc5c 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt32Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt32Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt32Buffer extends BaseDataBuffer { +public class UInt32Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt32Buffer extends BaseDataBuffer { */ public UInt32Buffer(long length) { super(length); + } + public UInt32Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt32Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt32Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt32Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt32Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt32Buffer(byte[] data, int length) { - super(data, length); - } - public UInt32Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java index 15af50dd1..84adf29b6 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt64Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt64Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt64Buffer extends BaseDataBuffer { +public class UInt64Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt64Buffer extends BaseDataBuffer { */ public UInt64Buffer(long length) { super(length); + } + public UInt64Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt64Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt64Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt64Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt64Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt64Buffer(byte[] data, int length) { - super(data, length); - } - public UInt64Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java similarity index 91% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java index 56f311e9e..d207d370a 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/UInt8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/UInt8Buffer.java @@ -14,11 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.indexer.Indexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.nio.ByteBuffer; @@ -28,7 +31,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ -public class UInt8Buffer extends BaseDataBuffer { +public class UInt8Buffer extends BaseCpuDataBuffer { /** * Meant for creating another view of a buffer @@ -47,7 +50,10 @@ public class UInt8Buffer extends BaseDataBuffer { */ public UInt8Buffer(long length) { super(length); + } + public UInt8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public UInt8Buffer(long length, boolean initialize) { @@ -111,18 +117,6 @@ public class UInt8Buffer extends BaseDataBuffer { super(data, copy, offset); } - public UInt8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - - public UInt8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - - public UInt8Buffer(byte[] data, int length) { - super(data, length); - } - public UInt8Buffer(float[] floats, boolean copy) { super(floats, copy); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java similarity index 89% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java rename to nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java index e2cdc9c2f..3f33cc044 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/Utf8Buffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/Utf8Buffer.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.linalg.api.buffer; +package org.nd4j.linalg.cpu.nativecpu.buffer; import lombok.Getter; @@ -23,11 +23,11 @@ import lombok.val; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.indexer.ByteIndexer; import org.bytedeco.javacpp.indexer.Indexer; -import org.bytedeco.javacpp.indexer.LongIndexer; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.pointers.PagedPointer; import java.io.UnsupportedEncodingException; import java.nio.ByteBuffer; @@ -39,7 +39,7 @@ import java.util.Collection; * * @author Adam Gibson */ -public class Utf8Buffer extends BaseDataBuffer { +public class Utf8Buffer extends BaseCpuDataBuffer { protected Collection references = new ArrayList<>(); @@ -62,21 +62,30 @@ public class Utf8Buffer extends BaseDataBuffer { } public Utf8Buffer(long length, boolean initialize) { - super(length, initialize); + /** + * Special case: we're creating empty buffer for length strings, each of 0 chars + */ + super((length + 1) * 8, true); + numWords = length; } public Utf8Buffer(long length, boolean initialize, MemoryWorkspace workspace) { - super(length, initialize, workspace); + /** + * Special case: we're creating empty buffer for length strings, each of 0 chars + */ + + super((length + 1) * 8, true, workspace); + numWords = length; + } + + public Utf8Buffer(ByteBuffer buffer, DataType dataType, long length, long offset) { + super(buffer, dataType, length, offset); } public Utf8Buffer(int[] ints, boolean copy, MemoryWorkspace workspace) { super(ints, copy, workspace); } - public Utf8Buffer(ByteBuffer buffer, int length, long offset) { - super(buffer, length, offset); - } - public Utf8Buffer(byte[] data, long numWords) { super(data.length, false); @@ -155,10 +164,6 @@ public class Utf8Buffer extends BaseDataBuffer { headerPointer.put(cnt, currentLength); } - public Utf8Buffer(ByteBuffer buffer, int length) { - super(buffer, length); - } - public String getString(long index) { if (index > numWords) throw new IllegalArgumentException("Requested index [" + index + "] is above actual number of words stored: [" + numWords + "]"); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 72f3e4553..fce391a05 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -24,6 +24,7 @@ import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.OpContext; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; @@ -84,14 +85,16 @@ public class CpuOpContext extends BaseOpContext implements OpContext { @Override public void setInputArray(int index, @NonNull INDArray array) { - nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + //nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextInputBuffer(context, index, array.isEmpty() ? null : ((BaseCpuDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), null); super.setInputArray(index, array); } @Override public void setOutputArray(int index, @NonNull INDArray array) { - nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + //nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextOutputBuffer(context, index, array.isEmpty() ? null : ((BaseCpuDataBuffer) array.data()).getOpaqueDataBuffer(), array.shapeInfoDataBuffer().addressPointer(), null); super.setOutputArray(index, array); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index c5e520ebc..dfd81c80b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -25,7 +25,6 @@ import lombok.val; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.LongIndexer; import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.compression.impl.AbstractCompressor; @@ -57,6 +56,9 @@ import org.nd4j.linalg.compression.CompressionDescriptor; import org.nd4j.linalg.compression.CompressionType; import org.nd4j.linalg.compression.ThresholdCompression; import org.nd4j.linalg.cpu.nativecpu.CpuTADManager; +import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer; +import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer; import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -67,15 +69,7 @@ import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.Optional; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; -import org.nd4j.nativeblas.LongPointerWrapper; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.nativeblas.Nd4jCpu; -import org.nd4j.nativeblas.OpaqueConstantDataBuffer; -import org.nd4j.nativeblas.OpaqueShapeList; -import org.nd4j.nativeblas.OpaqueTadPack; -import org.nd4j.nativeblas.OpaqueVariable; -import org.nd4j.nativeblas.OpaqueVariablesSet; +import org.nd4j.nativeblas.*; import java.util.*; @@ -209,29 +203,20 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long st = profilingConfigurableHookIn(op, tadBuffers.getFirst()); - Pointer x = op.x().data().addressPointer(); - Pointer z = op.z().data().addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op.z().isScalar()) { loop.execIndexReduceScalar(dummy, op.opNum(), - op.x().data().addressPointer(), - (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, - null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null); } else { loop.execIndexReduce(dummy, op.opNum(), - x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); } if (loop.lastErrorCode() != 0) @@ -398,30 +383,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * This gives us a pointer which is passed around in libnd4j. */ Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); if (op instanceof Variance) { if (ret.isScalar()) { loop.execSummaryStatsScalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, ((Variance) op).isBiasCorrected()); } else { Variance var = (Variance) op; try { loop.execSummaryStatsTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, - var.isBiasCorrected(), null, null);} catch (Throwable t){ + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + getPointerForExtraArgs(op, op.z().dataType()), + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + var.isBiasCorrected(), null, null); + } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); throw new RuntimeException("Native AccumulationOp execution (double) failed: " + str, t); } @@ -430,24 +411,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } //pairwise reduction like similarity of two arrays else if (op.y() != null && op.getOpType() == Op.Type.REDUCE3) { + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); if (op.isComplexAccumulation()) { try { loop.execReduce3All(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, - (LongPointer) tadBuffers.getFirst().addressPointer(), - new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), - (LongPointer) yTadBuffers.getFirst().addressPointer(), - new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, + (LongPointer) tadBuffers.getFirst().addressPointer(), new LongPointerWrapper(tadBuffers.getSecond().addressPointer()), + (LongPointer) yTadBuffers.getFirst().addressPointer(), new LongPointerWrapper(yTadBuffers.getSecond().addressPointer()) ); } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); @@ -455,27 +429,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } else if (ret.isScalar()) { loop.execReduce3Scalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); } else { try { loop.execReduce3Tad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, null, null, null, null); } catch (Throwable t){ String str = opInfoString(op, Optional.of(dimension)); @@ -488,35 +453,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null); + z, (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -525,51 +482,34 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { switch (op.getOpType()) { case REDUCE_FLOAT: loop.execReduceFloat2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_LONG: loop.execReduceLong2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_SAME: loop.execReduceSame2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case REDUCE_BOOL: loop.execReduceBool2(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType()), - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), + (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unsupported op used in reduce: "+ op.getOpType()); @@ -621,39 +561,28 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (extraz.get() == null) extraz.set(new PointerPointer(32)); - //PointerPointer dummy = extraz.get().put(hostTadShapeInfo, hostTadOffsets, devTadShapeInfoZ, devTadOffsetsZ); - + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case SCALAR: loop.execScalarTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; case SCALAR_BOOL: loop.execScalarBoolTad(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType()), - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer) hostTadShapeInfo, (LongPointer) hostTadOffsets, (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ); break; @@ -694,28 +623,26 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { return op.z(); } + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val scalar = ((BaseCpuDataBuffer) op.scalar().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case SCALAR: loop.execScalar(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.scalar().data().addressPointer(), (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType())); break; case SCALAR_BOOL: loop.execScalarBool(null, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.scalar().data().addressPointer(), (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + scalar, (LongPointer) op.scalar().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType())); break; default: @@ -820,6 +747,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { "; y: " + op.y().length() + ", shape " + Arrays.toString(op.y().shape()) + "; z: " + op.z().length() + ", shape " + Arrays.toString(op.z().shape())); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + switch (op.getOpType()) { case TRANSFORM_ANY: case TRANSFORM_FLOAT: @@ -829,54 +760,46 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Preconditions.checkArgument(op.x().dataType() == op.y().dataType() || op.y().dataType() == DataType.BOOL, "Op.X and Op.Y must have the same data type, but got " + op.x().dataType() + " vs " + op.y().dataType()); loop.execPairwiseTransform(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.z().dataType())); break; case TRANSFORM_BOOL: case PAIRWISE_BOOL: loop.execPairwiseTransformBool(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, getPointerForExtraArgs(op, op.x().dataType())); break; } } else { if (op.z() == null) - op.setZ(Nd4j.create(op.resultType(), op.x().shape())); + op.setZ(Nd4j.createUninitialized(op.resultType(), op.x().shape())); op.validateDataTypes(experimentalMode.get()); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case TRANSFORM_FLOAT: { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformFloat(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - xtraz); + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), + null, xtraz); break; } case TRANSFORM_STRICT: { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformStrict(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -884,10 +807,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val xtraz = getPointerForExtraArgs(op, op.z().dataType()); loop.execTransformSame(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -896,10 +817,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val opNum = op.opNum(); loop.execTransformAny(dummy, opNum, - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -908,10 +827,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val opNum = op.opNum(); loop.execTransformBool(dummy, opNum, - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, xtraz); break; } @@ -968,34 +885,25 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Pointer dimensionAddress = constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer(); + val x = ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); switch (op.getOpType()) { case BROADCAST: loop.execBroadcast(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; case BROADCAST_BOOL: loop.execBroadcastBool(dummy, op.opNum(), - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, null, - op.dimensions().data().addressPointer(), - (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), - null, - null); + ((BaseCpuDataBuffer) op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null); break; default: throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]"); @@ -1304,29 +1212,27 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { Preconditions.checkArgument(op.z().isR(), "Op.Z must have one of floating point types"); + val x = op.x() == null ? null : ((BaseCpuDataBuffer) op.x().data()).getOpaqueDataBuffer(); + val y = op.y() == null ? null : ((BaseCpuDataBuffer) op.y().data()).getOpaqueDataBuffer(); + val z = op.z() == null ? null : ((BaseCpuDataBuffer) op.z().data()).getOpaqueDataBuffer(); + if (op.x() != null && op.y() != null && op.z() != null) { // triple arg call loop.execRandom3(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.y().data().addressPointer(), (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + y, (LongPointer) op.y().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } else if (op.x() != null && op.z() != null) { //double arg call loop.execRandom2(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.x().data().addressPointer(), (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), - null, null, - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + x, (LongPointer) op.x().shapeInfoDataBuffer().addressPointer(), null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } else { // single arg call loop.execRandom(null, op.opNum(), rng.getStatePointer(), // rng state ptr - op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), - null, null, + z, (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), null, op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } @@ -1706,6 +1612,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw e; } catch (Exception e) { throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e); + //throw new RuntimeException(e); } } @@ -1733,6 +1640,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val result = exec(op, context); val states = context.getRngStates(); + // check if input & output needs update + for (val in:op.inputArguments()) { + if (!in.isEmpty()) + ((BaseCpuDataBuffer) in.data()).actualizePointerAndIndexer(); + } + + for (val out:op.outputArguments()) { + if (!out.isEmpty()) + ((BaseCpuDataBuffer) out.data()).actualizePointerAndIndexer(); + } + // pulling states back Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); @@ -1815,10 +1733,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } catch (Throwable t){ StringBuilder sb = new StringBuilder(); sb.append("Inputs: [("); - for( int i=0; i 0) sb.append("), ("); - sb.append(Shape.shapeToStringShort(inputArgs[i])); + sb.append(Shape.shapeToStringShort(inputArgs.get(i))); } sb.append(")]"); if(op instanceof DifferentialFunction && ((DifferentialFunction)op).getSameDiff() != null){ @@ -1979,7 +1897,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } @Override - public String getString(Utf8Buffer buffer, long index) { + public String getString(DataBuffer buffer, long index) { + Preconditions.checkArgument(buffer instanceof Utf8Buffer, "Expected Utf8Buffer"); + val addr = ((LongIndexer) buffer.indexer()).get(index); val ptr = new PagedPointer(addr); val str = new Nd4jCpu.utf8string(ptr); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index f2da4dc19..ba5cb74a4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -110,6 +110,74 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { } } +@Name("std::vector") public static class ConstNDArrayVector extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstNDArrayVector(Pointer p) { super(p); } + public ConstNDArrayVector(NDArray value) { this(1); put(0, value); } + public ConstNDArrayVector(NDArray ... array) { this(array.length); put(array); } + public ConstNDArrayVector() { allocate(); } + public ConstNDArrayVector(long n) { allocate(n); } + private native void allocate(); + private native void allocate(@Cast("size_t") long n); + public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x); + + public boolean empty() { return size() == 0; } + public native long size(); + public void clear() { resize(0); } + public native void resize(@Cast("size_t") long n); + + @Index(function = "at") public native @Const NDArray get(@Cast("size_t") long i); + public native ConstNDArrayVector put(@Cast("size_t") long i, NDArray value); + + public native @ByVal Iterator insert(@ByVal Iterator pos, @Const NDArray value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @Const NDArray get(); + } + + public NDArray[] get() { + NDArray[] array = new NDArray[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + + public NDArray pop_back() { + long size = size(); + NDArray value = get(size - 1); + resize(size - 1); + return value; + } + public ConstNDArrayVector push_back(NDArray value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public ConstNDArrayVector put(NDArray value) { + if (size() != 1) { resize(1); } + return put(0, value); + } + public ConstNDArrayVector put(NDArray ... array) { + if (size() != array.length) { resize(array.length); } + for (int i = 0; i < array.length; i++) { + put(i, array[i]); + } + return this; + } +} + @Name("std::vector") public static class NDArrayVector extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -261,12 +329,167 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, ANY = 100, AUTO = 200; // #endif +// Parsed from array/DataBuffer.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +// #ifndef DEV_TESTS_DATABUFFER_H +// #define DEV_TESTS_DATABUFFER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +@Namespace("nd4j") @NoOffset public static class DataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public DataBuffer position(long position) { + return (DataBuffer)super.position(position); + } + + + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, + @Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const nd4j::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const nd4j::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); + + public native @Cast("nd4j::DataType") int getDataType(); + public native void setDataType(@Cast("nd4j::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); +} +///// IMLEMENTATION OF INLINE METHODS ///// + +//////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////// + + + + + +// #endif //DEV_TESTS_DATABUFFER_H + + // Parsed from array/ConstantDataBuffer.h /******************************************************************************* @@ -756,6 +979,7 @@ bool verbose = false; // #include // #include // #include +// #include // #include // #include // #include @@ -804,25 +1028,19 @@ public native void setTADThreshold(int num); */ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -837,31 +1055,22 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex */ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -878,74 +1087,50 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcast( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -962,63 +1147,45 @@ public native void execBroadcastBool( public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransform( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execPairwiseTransformBool( @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1032,92 +1199,68 @@ public native void execPairwiseTransformBool( */ public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1130,118 +1273,82 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape); /** * @@ -1256,31 +1363,22 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi */ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * @@ -1293,31 +1391,22 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer */ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo); public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo); + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1333,82 +1422,58 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP */ public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongPointer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("Nd4jLong*") LongBuffer yTadOffsets); public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("Nd4jLong*") long[] yTadOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeInfo, @Cast("Nd4jLong*") LongPointer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer xTadShapeInfo, @Cast("Nd4jLong*") LongPointer xOffsets, @Cast("Nd4jLong*") LongPointer yTadShapeInfo, @Cast("Nd4jLong*") LongPointer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeInfo, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("Nd4jLong*") LongBuffer xOffsets, @Cast("Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("Nd4jLong*") LongBuffer yOffsets); public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeInfo, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeInfo, @Cast("Nd4jLong*") long[] dYShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] xTadShapeInfo, @Cast("Nd4jLong*") long[] xOffsets, @Cast("Nd4jLong*") long[] yTadShapeInfo, @Cast("Nd4jLong*") long[] yOffsets); @@ -1425,58 +1490,40 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, - Pointer dScalar, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, @Cast("Nd4jLong*") long[] hSscalarShapeInfo, @Cast("Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); /** @@ -1488,27 +1535,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin */ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1521,27 +1562,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e */ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * @@ -1556,35 +1591,26 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo */ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("bool") boolean biasCorrected, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets); @@ -1600,112 +1626,82 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** @@ -1723,81 +1719,57 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr */ public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeInfo, @Cast("Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("Nd4jLong*") LongPointer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, @Cast("Nd4jLong*") LongPointer dDimensionShape, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @Cast("Nd4jLong*") LongPointer tadOffsets, @Cast("Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("Nd4jLong*") LongPointer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeInfo, @Cast("Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("Nd4jLong*") LongBuffer dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, @Cast("Nd4jLong*") LongBuffer dDimensionShape, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @Cast("Nd4jLong*") LongBuffer tadOffsets, @Cast("Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("Nd4jLong*") LongBuffer tadOffsetsZ); public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeInfo, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeInfo, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, - Pointer hScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, - Pointer dScalars, @Cast("Nd4jLong*") long[] dScalarShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeInfo, @Cast("Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeInfo, @Cast("Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalars, @Cast("Nd4jLong*") long[] hScalarShapeInfo, @Cast("Nd4jLong*") long[] dScalarShapeInfo, Pointer extraParams, - Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, - Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape, + OpaqueDataBuffer dbDimension, @Cast("Nd4jLong*") long[] hDimensionShape, @Cast("Nd4jLong*") long[] dDimensionShape, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @@ -2160,10 +2132,8 @@ public native void deleteTadPack(OpaqueTadPack ptr); * @param zTadOffsets */ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongPointer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer zShapeInfo, @Cast("Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongPointer indexes, @Cast("Nd4jLong*") LongPointer tadShapeInfo, @@ -2171,10 +2141,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer zTadShapeInfo, @Cast("Nd4jLong*") LongPointer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer zShapeInfo, @Cast("Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") LongBuffer indexes, @Cast("Nd4jLong*") LongBuffer tadShapeInfo, @@ -2182,10 +2150,8 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer zTadShapeInfo, @Cast("Nd4jLong*") LongBuffer zTadOffsets); public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("Nd4jLong*") long[] dzShapeInfo, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] zShapeInfo, @Cast("Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, @Cast("Nd4jLong*") long[] indexes, @Cast("Nd4jLong*") long[] tadShapeInfo, @@ -2451,20 +2417,17 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2483,32 +2446,23 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongPointer hYShapeBuffer, @Cast("Nd4jLong*") LongPointer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("Nd4jLong*") LongBuffer dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hY, @Cast("Nd4jLong*") long[] hYShapeBuffer, - Pointer dY, @Cast("Nd4jLong*") long[] dYShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbY, @Cast("Nd4jLong*") long[] hYShapeBuffer, @Cast("Nd4jLong*") long[] dYShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** @@ -2525,26 +2479,20 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer hXShapeBuffer, @Cast("Nd4jLong*") LongPointer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongPointer hZShapeBuffer, @Cast("Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("Nd4jLong*") LongBuffer dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, @Cast("Nd4jPointer") Pointer state, - Pointer hX, @Cast("Nd4jLong*") long[] hXShapeBuffer, - Pointer dX, @Cast("Nd4jLong*") long[] dXShapeBuffer, - Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, - Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeBuffer, + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] hXShapeBuffer, @Cast("Nd4jLong*") long[] dXShapeBuffer, + OpaqueDataBuffer dbZ, @Cast("Nd4jLong*") long[] hZShapeBuffer, @Cast("Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); @@ -2587,52 +2535,6 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); -/** - * Grid operations - */ - - - - -/** - * - * @param extras - * @param opTypeA - * @param opNumA - * @param opTypeB - * @param opNumB - * @param N - * @param dx - * @param xShapeInfo - * @param dy - * @param yShapeInfo - * @param dz - * @param zShapeInfo - * @param extraA - * @param extraB - * @param scalarA - * @param scalarB - */ - /* -ND4J_EXPORT void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeBuffer, - void *dX, Nd4jLong *dXShapeBuffer, - void *hY, Nd4jLong *hYShapeBuffer, - void *dY, Nd4jLong *dYShapeBuffer, - void *hZ, Nd4jLong *hZShapeBuffer, - void *dZ, Nd4jLong *dZShapeBuffer, - void *extraA, - void *extraB, - double scalarA, - double scalarB); - -*/ - /** * * @param data @@ -2795,23 +2697,20 @@ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") l * @return */ public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, - @Cast("Nd4jLong*") LongPointer tadShapeInfo, - @Cast("Nd4jLong*") LongPointer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongPointer zShapeInfo, + @Cast("Nd4jLong*") LongPointer tadShapeInfo, + @Cast("Nd4jLong*") LongPointer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("Nd4jLong*") LongBuffer tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") LongBuffer zShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("Nd4jLong*") LongBuffer tadOffsets); public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, - @Cast("Nd4jLong*") long[] tadShapeInfo, - @Cast("Nd4jLong*") long[] tadOffsets); + OpaqueDataBuffer dbX, @Cast("Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("Nd4jLong*") long[] zShapeInfo, + @Cast("Nd4jLong*") long[] tadShapeInfo, + @Cast("Nd4jLong*") long[] tadOffsets); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jLong") long N, IntPointer dz, float threshold); public native @Cast("Nd4jLong") long encodeBitmap(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer dx, @Cast("Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jLong") long N, IntBuffer dz, float threshold); @@ -3108,6 +3007,8 @@ public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") bool public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); @@ -3140,6 +3041,28 @@ public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); +public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); +public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); +public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); +public native int dbLocality(OpaqueDataBuffer dataBuffer); +public native int dbDeviceId(OpaqueDataBuffer dataBuffer); +public native void dbSetDeviceId(OpaqueDataBuffer dataBuffer, int deviceId); +public native void dbTickHostRead(OpaqueDataBuffer dataBuffer); +public native void dbTickHostWrite(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceRead(OpaqueDataBuffer dataBuffer); +public native void dbTickDeviceWrite(OpaqueDataBuffer dataBuffer); +public native void dbClose(OpaqueDataBuffer dataBuffer); +public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); +public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); + public native int binaryLevel(); public native int optimalLevel(); @@ -3636,6 +3559,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include +// #include +// #include @@ -3847,10 +3772,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -6399,10 +6327,12 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setInputArray(int index, NDArray array); public native void setInputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); public native void setOutputArray(int index, NDArray array); public native void setOutputArray(int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); public native void setTArguments(DoublePointer arguments, int numberOfArguments); public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); @@ -9604,7 +9534,7 @@ public static final int PREALLOC_SIZE = 33554432; // #define BROADCAST(NAME) nd4j::BroadcastOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) - +public static final int ALL_STRINGS =UTF32; public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; @@ -11167,6 +11097,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #define PARAMETRIC_D() [&] (Parameters &p) -> Context* + +// #ifdef __CUDABLAS__ +// #endif + // #endif @@ -12014,6 +11948,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include +// #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 9d067b5bc..ec5c25d86 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -36,6 +36,7 @@ import java.util.Scanner; value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "memory/MemoryType.h", "array/DataType.h", + "array/DataBuffer.h", "array/ConstantDataBuffer.h", "array/ConstantDescriptor.h", "array/TadPack.h", @@ -160,6 +161,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("OpaqueVariablesSet").pointerTypes("OpaqueVariablesSet")) .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) + .put(new Info("OpaqueDataBuffer").pointerTypes("OpaqueDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) @@ -186,6 +188,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("std::pair").pointerTypes("IntIntPair").define()) .put(new Info("std::vector >").pointerTypes("IntVectorVector").define()) .put(new Info("std::vector >").pointerTypes("LongVectorVector").define()) + .put(new Info("std::vector").pointerTypes("ConstNDArrayVector").define()) .put(new Info("std::vector").pointerTypes("NDArrayVector").define()) .put(new Info("nd4j::graph::ResultWrapper").base("org.nd4j.nativeblas.ResultWrapperAbstraction").define()) .put(new Info("bool").cast().valueTypes("boolean").pointerTypes("BooleanPointer", "boolean[]")) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties index 4690d54f6..0b2489b53 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/resources/nd4j-native.properties @@ -28,7 +28,7 @@ native.ops= org.nd4j.nativeblas.Nd4jCpu ndarrayfactory.class = org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory ndarray.order = c resourcemanager_state = false -databufferfactory = org.nd4j.linalg.api.buffer.factory.DefaultDataBufferFactory +databufferfactory = org.nd4j.linalg.cpu.nativecpu.buffer.DefaultDataBufferFactory workspacemanager = org.nd4j.linalg.cpu.nativecpu.workspace.CpuWorkspaceManager alloc = javacpp fft = org.nd4j.linalg.fft.DefaultFFTInstance diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index de967ac40..15d6bd273 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1750,7 +1750,7 @@ public class ShapeOpValidation extends BaseOpValidation { Nd4j.getExecutioner().exec(op); - val output = op.outputArguments()[0]; + val output = op.outputArguments().get(0); assertEquals(exp, output); } @@ -2458,7 +2458,7 @@ public class ShapeOpValidation extends BaseOpValidation { Nd4j.exec(op); // System.out.println(in); -// System.out.println(op.outputArguments()[0]); +// System.out.println(op.outputArguments().get(0)); assertEquals(exp, op.getOutputArgument(0)); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index 89a72e5ac..fd1e14423 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -26,6 +26,8 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.util.PrintVariable; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -184,10 +186,22 @@ public class LoneTest extends BaseNd4jTest { assertEquals(max - 1, currentArgMax); } + @Test + public void testRPF() { + val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); + + log.info("--------"); + + val tad = array.tensorAlongDimension(1, 1, 2); + Nd4j.exec(new PrintVariable(tad, false)); + log.info("TAD native shapeInfo: {}", tad.shapeInfoDataBuffer().asLong()); + log.info("TAD Java shapeInfo: {}", tad.shapeInfoJava()); + log.info("TAD:\n{}", tad); + } @Test public void testConcat3D_Vstack_C() { - val shape = new long[]{1, 1000, 150}; + val shape = new long[]{1, 1000, 20}; List cArrays = new ArrayList<>(); List fArrays = new ArrayList<>(); @@ -200,15 +214,17 @@ public class LoneTest extends BaseNd4jTest { Nd4j.getExecutioner().commit(); - long time1 = System.currentTimeMillis(); - INDArray res = Nd4j.vstack(cArrays); - long time2 = System.currentTimeMillis(); + val time1 = System.currentTimeMillis(); + val res = Nd4j.vstack(cArrays); + val time2 = System.currentTimeMillis(); // log.info("Time spent: {} ms", time2 - time1); for (int e = 0; e < 32; e++) { - INDArray tad = res.tensorAlongDimension(e, 1, 2); + val tad = res.tensorAlongDimension(e, 1, 2); + assertEquals("Failed for TAD [" + e + "]",(double) e, tad.meanNumber().doubleValue(), 1e-5); + assertEquals((double) e, tad.getDouble(0), 1e-5); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java index 3aa43d858..8657d061e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/NDArrayTestsFortran.java @@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -28,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.executors.ExecutorServiceProvider; @@ -678,6 +680,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest { public void testPutSlice() { INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3); + Nd4j.exec(new PrintVariable(newSlice)); + log.info("Slice: {}", newSlice); n.putSlice(0, newSlice); assertEquals(getFailureMessage(), newSlice, n.slice(0)); @@ -993,14 +997,10 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray put = Nd4j.create(new double[] {5, 6}); row1.putRow(1, put); - INDArray row1Fortran = Nd4j.create(new double[][] {{1, 3}, {2, 4}}); INDArray putFortran = Nd4j.create(new double[] {5, 6}); row1Fortran.putRow(1, putFortran); assertEquals(row1, row1Fortran); - - - } @@ -1036,6 +1036,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test + @Ignore public void testTensorDot() { INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 81540f3c4..534720b29 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -83,6 +83,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -439,6 +440,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @Ignore public void testMmulOp() throws Exception { INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray z = Nd4j.create(2, 2); @@ -2863,7 +2865,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1); } - @Test public void testIdentity() { INDArray eye = Nd4j.eye(5); @@ -6077,6 +6078,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + //@Ignore public void testMatmul_128by256() { val mA = Nd4j.create(128, 156).assign(1.0f); val mB = Nd4j.create(156, 256).assign(1.0f); @@ -6244,6 +6246,14 @@ public class Nd4jTestsC extends BaseNd4jTest { } + @Test + public void testScalarPrint_1() { + val scalar = Nd4j.scalar(3.0f); + + Nd4j.exec(new PrintVariable(scalar, true)); + } + + @Test public void testValueArrayOf_1() { val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT); @@ -6986,6 +6996,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @Ignore public void testMatmul_vs_tf() throws Exception { // uncomment this line to initialize & propagate sgemm/dgemm pointer @@ -7168,6 +7179,16 @@ public class Nd4jTestsC extends BaseNd4jTest { } } + @Test + public void testScalarEquality_1() { + val x = Nd4j.scalar(1.0f); + val e = Nd4j.scalar(3.0f); + + x.addi(2.0f); + + assertEquals(e, x); + } + @Test public void testStack(){ INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4); @@ -8217,6 +8238,17 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, out); } + @Test + public void testPutOverwrite(){ + INDArray arr = Nd4j.create(DataType.DOUBLE, 10); + arr.putScalar(0, 10); + System.out.println(arr); + INDArray arr2 = Nd4j.createFromArray(3.0, 3.0, 3.0); + val view = arr.get(new INDArrayIndex[]{NDArrayIndex.interval(1, 4)}); + view.assign(arr2); + System.out.println(arr); + } + @Test public void testEmptyReshapingMinus1(){ INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java index 2b3f17072..41c0e3e2d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.util.SerializationUtils; import java.io.*; +import java.util.Arrays; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -224,7 +225,7 @@ public class DoubleDataBufferTest extends BaseNd4jTest { double[] old = buffer.asDouble(); buffer.reallocate(6); assertEquals(6, buffer.capacity()); - assertArrayEquals(old, buffer.asDouble(), 1e-1); + assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1); } @Test @@ -239,7 +240,7 @@ public class DoubleDataBufferTest extends BaseNd4jTest { assertEquals(4, buffer.capacity()); buffer.reallocate(6); assertEquals(6, buffer.capacity()); - assertArrayEquals(old, buffer.asDouble(), 1e-1); + assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1); workspace.close(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java index 6324037e2..1c6c3ac44 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestSerializationDoubleToFloat.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ndarray; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Test; @@ -24,18 +25,23 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; +import org.nd4j.linalg.api.ops.util.PrintAffinity; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; +import org.nd4j.nativeblas.NativeOpsHolder; import java.io.*; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created by susaneraly on 7/2/16. */ +@Slf4j @RunWith(Parameterized.class) public class TestSerializationDoubleToFloat extends BaseNd4jTest { @@ -53,7 +59,7 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest { @Test public void testSerializationFullArrayNd4jWriteRead() throws Exception { - int length = 100; + int length = 4; //WRITE OUT A DOUBLE ARRAY //Hack before setting datatype - fix already in r119_various branch @@ -61,7 +67,7 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest { val initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.DOUBLE); - INDArray arr = Nd4j.linspace(1, length, length).reshape('c', 10, 10); + INDArray arr = Nd4j.linspace(1, length, length).reshape('c', 2, 2); arr.subi(50.0123456); //assures positive and negative numbers with decimal points ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -74,9 +80,11 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest { //Nd4j.create(1); DataTypeUtil.setDTypeForContext(DataType.FLOAT); System.out.println("The data opType is " + Nd4j.dataType()); - INDArray arr1 = Nd4j.linspace(1, length, length).reshape('c', 10, 10); + INDArray arr1 = Nd4j.linspace(1, length, length).reshape('c', 2, 2); arr1.subi(50.0123456); + log.info("A ---------------"); + INDArray arr2; try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) { arr2 = Nd4j.read(dis); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 7b72f4bae..ceac20e51 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -63,6 +64,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { val y = Nd4j.createFromArray(new float[]{1.f, 1.f, 1.f, 1.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(1.f); + //Nd4j.exec(new PrintVariable(x, "X array")); + //Nd4j.exec(new PrintVariable(y, "Y array")); + val z = x.add(y); assertEquals(e, z); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index c6efda9b0..d0bcb3975 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.crash; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -41,6 +42,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions; */ @Slf4j @RunWith(Parameterized.class) +@Ignore public class CrashTest extends BaseNd4jTest { public CrashTest(Nd4jBackend backend) { super(backend); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index dfbf5adf5..0ae56350d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -419,7 +419,7 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().exec(op); - assertEquals(1, op.outputArguments().length); + assertEquals(1, op.outputArguments().size()); val output = op.getOutputArgument(0); assertArrayEquals(new long[]{5, 10}, output.shape()); @@ -435,7 +435,7 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().exec(op); - assertEquals(1, op.outputArguments().length); + assertEquals(1, op.outputArguments().size()); val output = op.getOutputArgument(0); assertArrayEquals(new long[]{5, 10}, output.shape()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java new file mode 100644 index 000000000..38a5ab763 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.custom; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.compat.CompatStringSplit; +import org.nd4j.linalg.api.ops.util.PrintVariable; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * This is special test suit: we test operations that on C++ side modify arrays that come from Java + */ +@Slf4j +public class ExpandableOpsTests { + + @Test + public void testCompatStringSplit_1() throws Exception { + val array = Nd4j.create("first string", "second"); + val delimiter = Nd4j.create(" "); + + val exp0 = Nd4j.createFromArray(new long[] {0,0, 0,1, 1,0}); + val exp1 = Nd4j.create("first", "string", "second"); + + val results = Nd4j.exec(new CompatStringSplit(array, delimiter)); + assertNotNull(results); + assertEquals(2, results.length); + + assertEquals(exp0, results[0]); + assertEquals(exp1, results[1]); + } + + @Test + public void test() { + val arr = Nd4j.createFromArray(0, 1, 2, 3, 4, 5, 6, 7, 8).reshape(3, 3); + Nd4j.exec(new PrintVariable(arr)); + + val row = arr.getRow(1); + Nd4j.exec(new PrintVariable(row)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java new file mode 100644 index 000000000..27cdf5a42 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java @@ -0,0 +1,83 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.multithreading; + +import lombok.val; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.junit.Assert.assertEquals; + +/** + * @author raver119@gmail.com + */ +public class MultithreadedTests { + + @Test + public void basicMigrationTest_1() throws Exception { + if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) + return; + + val exp = Nd4j.create(DataType.INT32, 5, 5).assign(2); + + val hash = new HashSet(); + + // we're creating bunch of arrays on different devices + val list = new ArrayList(); + for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val t = e; + val thread = new Thread(new Runnable() { + @Override + public void run() { + for (int f = 0; f < 10; f++) { + val array = Nd4j.create(DataType.INT32, 5, 5).assign(1); + + // store current deviceId for further validation + hash.add(Nd4j.getAffinityManager().getDeviceForCurrentThread()); + + // make sure INDArray has proper affinity set + assertEquals(Nd4j.getAffinityManager().getDeviceForCurrentThread(), Nd4j.getAffinityManager().getDeviceForArray(array)); + + list.add(array); + } + }; + }); + + thread.start(); + thread.join(); + } + + // lets make sure all devices covered + assertEquals(Nd4j.getAffinityManager().getNumberOfDevices(), hash.size()); + + // make sure nothing failed in threads + assertEquals(10 * Nd4j.getAffinityManager().getNumberOfDevices(), list.size()); + + // now we're going to use arrays on current device, so data will be migrated + for (val arr:list) { + arr.addi(1); + + assertEquals(exp, arr); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 042abca1f..0fc085abe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -309,7 +309,7 @@ public class OpExecutionerTests extends BaseNd4jTest { val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @@ -520,7 +520,7 @@ public class OpExecutionerTests extends BaseNd4jTest { // System.out.println("Data:" + input.data().length()); val softMax = new SoftMax(input); Nd4j.getExecutioner().exec((CustomOp) softMax); - assertEquals(assertion, softMax.outputArguments()[0]); + assertEquals(assertion, softMax.outputArguments().get(0)); } @@ -559,7 +559,7 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 3e7551ae4..4e16544b8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -326,7 +326,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @Test @@ -426,12 +426,12 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); opExecutioner.exec((CustomOp) softMax); - assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val softmax = new SoftMax(linspace.dup()); Nd4j.getExecutioner().exec((CustomOp) softmax); - assertEquals(linspace.rows(), softmax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); + assertEquals(linspace.rows(), softmax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1); } @@ -440,7 +440,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val max = new SoftMax(linspace); Nd4j.getExecutioner().exec((CustomOp) max); - linspace.assign(max.outputArguments()[0]); + linspace.assign(max.outputArguments().get(0)); assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java index 4d48a6a98..093fc2ac1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -105,6 +106,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test + @Ignore public void testTrackerCpu_1() { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("native")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index fc6a034fd..8a06bd7e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -488,7 +488,7 @@ public class RandomTests extends BaseNd4jTest { @Test public void testLegacyDistribution1() { NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0); - INDArray z1 = distribution.sample(new int[] {1, 30000000}); + INDArray z1 = distribution.sample(new int[] {1, 1000000}); assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01); assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java index 97f591cff..30bdbfb37 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java @@ -87,8 +87,6 @@ public class ConcatTests extends BaseNd4jTest { assertTrue(firstRet.isColumnVector()); INDArray secondRet = Nd4j.concat(1, first, second); assertTrue(secondRet.isRowVector()); - - } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index d98e9218e..d5686bba5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.junit.After; import org.junit.Before; import org.junit.Ignore; @@ -29,6 +30,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -176,6 +178,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test public void testLeverageTo2() { + val exp = Nd4j.scalar(15.0); try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopOverTimeConfig, "EXT")) { INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f}); @@ -192,6 +195,10 @@ public class BasicWorkspaceTests extends BaseNd4jTest { assertEquals(0, wsOne.getCurrentSize()); assertEquals(15f, array3.sumNumber().floatValue(), 0.01f); + + array2.assign(0); + + assertEquals(15f, array3.sumNumber().floatValue(), 0.01f); } try (Nd4jWorkspace wsTwo = diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 15249acc9..5ccbc54cc 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -81,12 +81,10 @@ public abstract class BaseDataBuffer implements DataBuffer { protected transient DataBuffer wrappedDataBuffer; protected transient long workspaceGenerationId = 0L; - //protected Collection referencing = Collections.synchronizedSet(new HashSet()); - //protected boolean isPersist = false; protected AllocationMode allocationMode; - protected transient Pointer pointer; - protected transient Indexer indexer; - //protected AtomicBoolean dirty = new AtomicBoolean(false); + + protected transient Indexer indexer = null; + protected transient Pointer pointer = null; protected transient boolean attached = false; protected transient MemoryWorkspace parentWorkspace; @@ -94,7 +92,6 @@ public abstract class BaseDataBuffer implements DataBuffer { // Allocator-related stuff. Moved down here to avoid opType casting. protected transient DataBuffer originalBuffer; protected transient long originalOffset = 0; - protected transient Long trackingPoint; protected transient boolean constant = false; protected transient boolean released = false; @@ -203,7 +200,6 @@ public abstract class BaseDataBuffer implements DataBuffer { this.originalOffset = offset; // + underlyingBuffer.originalOffset(); } - pointer = underlyingBuffer.pointer(); setIndexer(underlyingBuffer.indexer()); } @@ -217,378 +213,6 @@ public abstract class BaseDataBuffer implements DataBuffer { return originalBuffer; } - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(float[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - - } - - public BaseDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) { - this(data, copy, workspace); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(float[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new FloatPointer(data); - - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - //wrappedBuffer = pointer.asByteBuffer(); - - length = data.length; - underlyingLength = data.length; - } - - public BaseDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - //wrappedBuffer = pointer.asByteBuffer(); - } - - public BaseDataBuffer(double[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - - public BaseDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = IntIndexer.create((IntPointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - public BaseDataBuffer(long[] data, boolean copy, MemoryWorkspace workspace) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - length = data.length; - underlyingLength = data.length; - attached = true; - parentWorkspace = workspace; - - initTypeAndSize(); - - //log.info("Allocating FloatPointer from array of {} elements", data.length); - - pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data); - workspaceGenerationId = workspace.getGenerationId(); - indexer = LongIndexer.create((LongPointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - } - - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(double[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = data.length; - this.length = underlyingLength - offset; - } - - public BaseDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) { - this(data, copy, workspace); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = data.length; - this.length = underlyingLength - offset; - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(double[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new DoublePointer(data); - indexer = DoubleIndexer.create((DoublePointer) pointer); - //wrappedBuffer = pointer.asByteBuffer(); - - length = data.length; - underlyingLength = data.length; - } - - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(int[] data, boolean copy, long offset) { - this(data, copy); - this.offset = offset; - this.originalOffset = offset; - this.length = data.length - offset; - this.underlyingLength = data.length; - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(int[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new IntPointer(data); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - length = data.length; - underlyingLength = data.length; - - // // log.info("Creating new buffer of size: {}; dtype: {}; B", data.length, dataType()); - } - - /** - * - * @param data - * @param copy - */ - public BaseDataBuffer(long[] data, boolean copy) { - allocationMode = AllocUtil.getAllocationModeFromContext(); - initTypeAndSize(); - - pointer = new LongPointer(data); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - length = data.length; - underlyingLength = data.length; - } - - /** - * - * @param data - */ - public BaseDataBuffer(double[] data) { - this(data, true); - } - - /** - * - * @param data - */ - public BaseDataBuffer(int[] data) { - this(data, true); - } - - /** - * - * @param data - */ - public BaseDataBuffer(float[] data) { - this(data, true); - } - - public BaseDataBuffer(float[] data, MemoryWorkspace workspace) { - this(data, true, workspace); - } - - /** - * - * @param length - * @param elementSize - */ - public BaseDataBuffer(int length, int elementSize, long offset) { - this(length, elementSize); - this.offset = offset; - this.originalOffset = offset; - this.length = length - offset; - this.underlyingLength = length; - } - - /** - * - * @param length - * @param elementSize - */ - public BaseDataBuffer(long length, int elementSize) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - allocationMode = AllocUtil.getAllocationModeFromContext(); - this.length = length; - this.underlyingLength = length; - this.elementSize = (byte) elementSize; - - if (dataType() == DataType.DOUBLE) { - pointer = new DoublePointer(length); - indexer = DoubleIndexer.create((DoublePointer) pointer); - } else if (dataType() == DataType.FLOAT) { - pointer = new FloatPointer(length); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - } else if (dataType() == DataType.INT) { - pointer = new IntPointer(length); - setIndexer(IntIndexer.create((IntPointer) pointer)); - } else if (dataType() == DataType.LONG) { - pointer = new LongPointer(length); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } else if (dataType() == DataType.SHORT) { - pointer = new ShortPointer(length); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.BYTE) { - pointer = new BytePointer(length); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UBYTE) { - pointer = new BytePointer(length); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UTF8) { - pointer = new LongPointer(length); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } - - // log.info("Creating new buffer of size: {}; dtype: {}; C", length, dataType()); - } - - /** - * Create a data buffer from - * the given length - * - * @param buffer - * @param length - */ - public BaseDataBuffer(ByteBuffer buffer, long length, long offset) { - this(buffer, length); - this.offset = offset; - this.originalOffset = offset; - this.underlyingLength = length; - this.length = length - offset; - - } - - /** - * Create a data buffer from - * the given length - * - * @param buffer - * @param length - */ - public BaseDataBuffer(ByteBuffer buffer, long length) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - - this.length = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - - switch (dataType()){ - case DOUBLE: - pointer = new DoublePointer(buffer.asDoubleBuffer()); - setIndexer(DoubleIndexer.create((DoublePointer) pointer)); - break; - case FLOAT: - pointer = new FloatPointer(buffer.asFloatBuffer()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - break; - case HALF: - pointer = new ShortPointer(buffer.asShortBuffer()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - break; - case LONG: - pointer = new LongPointer(buffer.asLongBuffer()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - break; - case INT: - pointer = new IntPointer(buffer.asIntBuffer()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - break; - case SHORT: - pointer = new ShortPointer(buffer.asShortBuffer()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - break; - case UBYTE: //Fall through - case BYTE: - pointer = new BytePointer(buffer); - setIndexer(UByteIndexer.create((BytePointer)pointer)); - break; - case BOOL: - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - break; - case UTF8: - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - break; - case BFLOAT16: - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - break; - case UINT16: - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - break; - case UINT32: - pointer = new IntPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(IntIndexer.create((IntPointer) pointer)); - break; - case UINT64: - pointer = new LongPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(LongIndexer.create((LongPointer) pointer)); - break; - } - -// log.info("Creating new buffer of size: {}; dtype: {}; D", length, dataType()); - } //sets the nio wrapped buffer (allows to be overridden for other use cases like cuda) protected void setNioBuffer() { @@ -598,17 +222,6 @@ public abstract class BaseDataBuffer implements DataBuffer { } - - /** - * - * @param data - * @param length - */ - public BaseDataBuffer(byte[] data, long length) { - this(ByteBuffer.wrap(data), length); - } - - /** * Returns the indexer for the buffer * @@ -662,7 +275,6 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override @Deprecated public void persist() { - //isPersist = true; throw new UnsupportedOperationException(); } @@ -678,230 +290,10 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new UnsupportedOperationException(); } - private void fillPointerWithZero() { + protected void fillPointerWithZero() { Pointer.memset(this.pointer(), 0, getElementSize() * length()); } - /** - * Instantiate a buffer with the given length - * - * @param length the length of the buffer - */ - protected BaseDataBuffer(long length) { - this(length, true); - } - - protected BaseDataBuffer(long length, boolean initialize) { - if (length < 0) - throw new IllegalArgumentException("Length must be >= 0"); - initTypeAndSize(); - this.length = length; - this.underlyingLength = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - if (length < 0) - throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); - - if (dataType() == DataType.DOUBLE) { - pointer = new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.FLOAT) { - pointer = new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - - } else if (dataType() == DataType.HALF) { - pointer = new ShortPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BFLOAT16) { - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.INT) { - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.LONG) { - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BYTE) { - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.SHORT) { - pointer = new ShortPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UBYTE) { - pointer = new BytePointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT16) { - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT32) { - pointer = new IntPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(IntIndexer.create((IntPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UINT64) { - pointer = new LongPointer(length()); - // FIXME: we need unsigned indexer here - setIndexer(LongIndexer.create((LongPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.BOOL) { - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } else if (dataType() == DataType.UTF8) { - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - - if (initialize) - fillPointerWithZero(); - } - - //// log.info("Creating new buffer of size: {}; dtype: {}; A", length, dataType()); - } - - protected BaseDataBuffer(long length, boolean initialize, MemoryWorkspace workspace) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); - initTypeAndSize(); - this.length = length; - this.underlyingLength = length; - allocationMode = AllocUtil.getAllocationModeFromContext(); - - - - if (length < 0) - throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); - - if (dataType() == DataType.DOUBLE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asDoublePointer(); //new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - - } else if (dataType() == DataType.FLOAT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asFloatPointer(); //new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - - } else if (dataType() == DataType.HALF) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - - } else if (dataType() == DataType.BFLOAT16) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new FloatPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.INT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - } else if (dataType() == DataType.UINT32) { - attached = true; - parentWorkspace = workspace; - - // FIXME: need unsigned indexer here - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - - } else if (dataType() == DataType.UINT64) { - attached = true; - parentWorkspace = workspace; - - // FIXME: need unsigned indexer here - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new IntPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - - } else if (dataType() == DataType.LONG) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } else if (dataType() == DataType.BYTE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UBYTE) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBytePointer(); //new LongPointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - } else if (dataType() == DataType.UINT16) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new IntPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - - } else if (dataType() == DataType.SHORT) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asShortPointer(); //new LongPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - } else if (dataType() == DataType.BOOL) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asBoolPointer(); //new LongPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - } else if (dataType() == DataType.UTF8) { - attached = true; - parentWorkspace = workspace; - - pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asLongPointer(); //new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - } - - workspaceGenerationId = workspace.getGenerationId(); - - } @Override public void copyAtStride(DataBuffer buf, long n, long stride, long yStride, long offset, long yOffset) { @@ -930,6 +322,9 @@ public abstract class BaseDataBuffer implements DataBuffer { //return referencing; } + public abstract Pointer addressPointer(); + + /* @Override public Pointer addressPointer() { if (released) @@ -937,7 +332,8 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() > 0) { Pointer ret; - final long retAddress = pointer().address() + getElementSize() * offset(); + // offset is accounted at native side + final long retAddress = pointer().address(); // directly set address at construction since Pointer.address has not setter. if (dataType() == DataType.DOUBLE) { ret = new DoublePointer(pointer()) { @@ -976,13 +372,14 @@ public abstract class BaseDataBuffer implements DataBuffer { } return pointer(); } + */ @Override public long address() { if (released) throw new IllegalStateException("You can't use DataBuffer once it was released"); - return pointer().address() + getElementSize() * offset(); + return pointer().address(); } @Override @@ -1273,7 +670,7 @@ public abstract class BaseDataBuffer implements DataBuffer { try { UByteIndexer u = (UByteIndexer) indexer; for (int i = 0; i < length(); i++) { - dos.writeByte(u.get(offset() + i)); + dos.writeByte(u.get(i)); } } catch (IOException e) { throw new RuntimeException(e); @@ -1431,29 +828,29 @@ public abstract class BaseDataBuffer implements DataBuffer { } switch (dataType()) { case FLOAT: - return ((FloatIndexer) indexer).get(offset() + i); + return ((FloatIndexer) indexer).get(i); case UINT32: case INT: - return ((IntIndexer) indexer).get(offset() + i); + return ((IntIndexer) indexer).get(i); case BFLOAT16: - return ((Bfloat16Indexer) indexer).get(offset() + i); + return ((Bfloat16Indexer) indexer).get(i); case HALF: - return ((HalfIndexer) indexer).get(offset() + i); + return ((HalfIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case UINT64: case LONG: - return ((LongIndexer) indexer).get(offset() + i); + return ((LongIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1.0 : 0.0; + return ((BooleanIndexer) indexer).get(i) ? 1.0 : 0.0; case DOUBLE: - return ((DoubleIndexer) indexer).get(offset() + i); + return ((DoubleIndexer) indexer).get(i); case BYTE: - return ((ByteIndexer) indexer).get(offset() + i); + return ((ByteIndexer) indexer).get(i); case UBYTE: - return ((UByteIndexer) indexer).get(offset() + i); + return ((UByteIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get double value from buffer of type " + dataType()); } @@ -1466,29 +863,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case FLOAT: - return (long) ((FloatIndexer) indexer).get(offset() + i); + return (long) ((FloatIndexer) indexer).get(i); case DOUBLE: - return (long) ((DoubleIndexer) indexer).get(offset() + i); + return (long) ((DoubleIndexer) indexer).get(i); case BFLOAT16: - return (long) ((Bfloat16Indexer) indexer).get(offset() + i); + return (long) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (long) ((HalfIndexer) indexer).get(offset() + i); + return (long) ((HalfIndexer) indexer).get( i); case UINT64: case LONG: - return ((LongIndexer) indexer).get(offset() + i); + return ((LongIndexer) indexer).get(i); case UINT32: case INT: - return (long) ((IntIndexer) indexer).get(offset() + i); + return (long) ((IntIndexer) indexer).get(i); case UINT16: - return (long) ((UShortIndexer) indexer).get(offset() + i); + return (long) ((UShortIndexer) indexer).get(i); case SHORT: - return (long) ((ShortIndexer) indexer).get(offset() + i); + return (long) ((ShortIndexer) indexer).get(i); case BYTE: - return (long) ((ByteIndexer) indexer).get(offset() + i); + return (long) ((ByteIndexer) indexer).get(i); case UBYTE: - return (long) ((UByteIndexer) indexer).get(offset() + i); + return (long) ((UByteIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1L : 0L; + return ((BooleanIndexer) indexer).get(i) ? 1L : 0L; default: throw new UnsupportedOperationException("Cannot get long value from buffer of type " + dataType()); } @@ -1505,26 +902,26 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (short) ((DoubleIndexer) indexer).get(offset() + i); + return (short) ((DoubleIndexer) indexer).get(i); case BFLOAT16: - return (short) ((Bfloat16Indexer) indexer).get(offset() + i); + return (short) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (short) ((HalfIndexer) indexer).get(offset() + i); + return (short) ((HalfIndexer) indexer).get(i); case BOOL: - return (short) (((BooleanIndexer) indexer).get(offset() + i) ? 1 : 0); + return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0); case UINT32: case INT: - return (short) ((IntIndexer) indexer).get(offset() + i); + return (short) ((IntIndexer) indexer).get(i); case UINT16: case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case BYTE: - return (short) ((ByteIndexer) indexer).get(offset() + i); + return (short) ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (short) ((LongIndexer) indexer).get(offset() + i); + return (short) ((LongIndexer) indexer).get(i); case FLOAT: - return (short) ((FloatIndexer) indexer).get(offset() + i); + return (short) ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get short value from buffer of type " + dataType()); } @@ -1546,29 +943,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (float) ((DoubleIndexer) indexer).get(offset() + i); + return (float) ((DoubleIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1.f : 0.f; + return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f; case UINT32: case INT: - return (float) ((IntIndexer) indexer).get(offset() + i); + return (float) ((IntIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return (float) ((ShortIndexer) indexer).get(offset() + i); + return (float) ((ShortIndexer) indexer).get(i); case BFLOAT16: - return ((Bfloat16Indexer) indexer).get(offset() + i); + return ((Bfloat16Indexer) indexer).get(i); case HALF: - return ((HalfIndexer) indexer).get(offset() + i); + return ((HalfIndexer) indexer).get(i); case UBYTE: - return (float) ((UByteIndexer) indexer).get(offset() + i); + return (float) ((UByteIndexer) indexer).get(i); case BYTE: - return (float) ((ByteIndexer) indexer).get(offset() + i); + return (float) ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (float) ((LongIndexer) indexer).get(offset() + i); + return (float) ((LongIndexer) indexer).get(i); case FLOAT: - return ((FloatIndexer) indexer).get(offset() + i); + return ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get float value from buffer of type " + dataType()); } @@ -1581,29 +978,29 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: - return (int) ((DoubleIndexer) indexer).get(offset() + i); + return (int) ((DoubleIndexer) indexer).get(i); case BOOL: - return ((BooleanIndexer) indexer).get(offset() + i) ? 1 : 0; + return ((BooleanIndexer) indexer).get(i) ? 1 : 0; case UINT32: case INT: - return ((IntIndexer) indexer).get(offset() + i); + return ((IntIndexer) indexer).get(i); case BFLOAT16: - return (int) ((Bfloat16Indexer) indexer).get(offset() + i); + return (int) ((Bfloat16Indexer) indexer).get(i); case HALF: - return (int) ((HalfIndexer) indexer).get(offset() + i); + return (int) ((HalfIndexer) indexer).get(i); case UINT16: - return ((UShortIndexer) indexer).get(offset() + i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return ((ShortIndexer) indexer).get(offset() + i); + return ((ShortIndexer) indexer).get(i); case UBYTE: - return ((UByteIndexer) indexer).get(offset() + i); + return ((UByteIndexer) indexer).get(i); case BYTE: - return ((ByteIndexer) indexer).get(offset() + i); + return ((ByteIndexer) indexer).get(i); case UINT64: case LONG: - return (int) ((LongIndexer) indexer).get(offset() + i); + return (int) ((LongIndexer) indexer).get(i); case FLOAT: - return (int) ((FloatIndexer) indexer).get(offset() + i); + return (int) ((FloatIndexer) indexer).get(i); default: throw new UnsupportedOperationException("Cannot get integer value from buffer of type " + dataType()); } @@ -1623,79 +1020,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return getFloat(i); } - public void pointerIndexerByCurrentType(DataType currentType) { - switch (currentType) { - case UINT64: - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - type = DataType.UINT64; - break; - case LONG: - pointer = new LongPointer(length()); - setIndexer(LongIndexer.create((LongPointer) pointer)); - type = DataType.LONG; - break; - case UINT32: - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - type = DataType.UINT32; - break; - case INT: - pointer = new IntPointer(length()); - setIndexer(IntIndexer.create((IntPointer) pointer)); - type = DataType.INT; - break; - case UINT16: - pointer = new ShortPointer(length()); - setIndexer(UShortIndexer.create((ShortPointer) pointer)); - type = DataType.UINT16; - break; - case SHORT: - pointer = new ShortPointer(length()); - setIndexer(ShortIndexer.create((ShortPointer) pointer)); - type = DataType.SHORT; - break; - case UBYTE: - pointer = new BytePointer(length()); - setIndexer(UByteIndexer.create((BytePointer) pointer)); - type = DataType.UBYTE; - break; - case BYTE: - pointer = new BytePointer(length()); - setIndexer(ByteIndexer.create((BytePointer) pointer)); - type = DataType.BYTE; - break; - case DOUBLE: - pointer = new DoublePointer(length()); - indexer = DoubleIndexer.create((DoublePointer) pointer); - type = DataType.DOUBLE; - break; - case FLOAT: - pointer = new FloatPointer(length()); - setIndexer(FloatIndexer.create((FloatPointer) pointer)); - type = DataType.FLOAT; - break; - case BFLOAT16: - pointer = new ShortPointer(length()); - setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); - type = DataType.BFLOAT16; - break; - case HALF: - pointer = new ShortPointer(length()); - setIndexer(HalfIndexer.create((ShortPointer) pointer)); - type = DataType.HALF; - break; - case BOOL: - pointer = new BooleanPointer(length()); - setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); - type = DataType.BOOL; - break; - case COMPRESSED: - break; - default: - throw new UnsupportedOperationException(); - } - } + public abstract void pointerIndexerByCurrentType(DataType currentType); public void putByDestinationType(long i, Number element, DataType globalType) { if (globalType == DataType.INT || type == DataType.INT || globalType == DataType.UINT16 || globalType == DataType.UBYTE || globalType == DataType.SHORT|| globalType == DataType.BYTE || globalType == DataType.BOOL) { @@ -1722,47 +1047,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0.0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (int) element); + ((UByteIndexer) indexer).put(i, (int) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int)element); + ((UShortIndexer) indexer).put(i, (int)element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, (long) element); + ((LongIndexer) indexer).put(i, (long) element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element); + ((Bfloat16Indexer) indexer).put(i, element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element); + ((HalfIndexer) indexer).put(i, element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element); + ((FloatIndexer) indexer).put(i, element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new IllegalStateException("Unsupported type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1772,47 +1093,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element > 0.0); + ((BooleanIndexer) indexer).put(i, element > 0.0); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (short) element); + ((UByteIndexer) indexer).put(i, (short) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int) element); + ((UShortIndexer) indexer).put(i, (int) element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, (long) element); + ((LongIndexer) indexer).put(i, (long) element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, (float) element); + ((Bfloat16Indexer) indexer).put(i, (float) element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, (float) element); + ((HalfIndexer) indexer).put(i, (float) element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, (float) element); + ((FloatIndexer) indexer).put(i, (float) element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1822,47 +1139,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, element); + ((UByteIndexer) indexer).put(i, element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, element); + ((UShortIndexer) indexer).put(i, element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, element); + ((IntIndexer) indexer).put(i, element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element); + ((LongIndexer) indexer).put(i, element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element); + ((Bfloat16Indexer) indexer).put(i, element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element); + ((HalfIndexer) indexer).put(i, element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element); + ((FloatIndexer) indexer).put(i, element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element); + ((DoubleIndexer) indexer).put(i, element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1872,47 +1185,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element); + ((BooleanIndexer) indexer).put(i, element); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, element ? (byte)1 : (byte) 0); + ((ByteIndexer) indexer).put(i, element ? (byte)1 : (byte) 0); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, element ? (byte)1 : (byte) 0); + ((UByteIndexer) indexer).put(i, element ? (byte)1 : (byte) 0); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((UShortIndexer) indexer).put(i, element ? 1 : 0); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, element ? (short) 1 : (short) 0); + ((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0); break; case INT: case UINT32: - ((IntIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((IntIndexer) indexer).put(i, element ? 1 : 0); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element ? 1 : 0); + ((LongIndexer) indexer).put(i, element ? 1 : 0); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((Bfloat16Indexer) indexer).put(i, element ? 1.0f : 0.0f); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((HalfIndexer) indexer).put(i, element ? 1.0f : 0.0f); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, element ? 1.0f : 0.0f); + ((FloatIndexer) indexer).put(i, element ? 1.0f : 0.0f); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, element ? 1.0 : 0.0); + ((DoubleIndexer) indexer).put(i, element ? 1.0 : 0.0); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -1922,47 +1231,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); break; case BYTE: - ((ByteIndexer) indexer).put(offset() + i, (byte) element); + ((ByteIndexer) indexer).put(i, (byte) element); break; case UBYTE: - ((UByteIndexer) indexer).put(offset() + i, (short) element); + ((UByteIndexer) indexer).put(i, (short) element); break; case UINT16: - ((UShortIndexer) indexer).put(offset() + i, (int) element); + ((UShortIndexer) indexer).put(i, (int) element); break; case SHORT: - ((ShortIndexer) indexer).put(offset() + i, (short) element); + ((ShortIndexer) indexer).put(i, (short) element); break; case UINT32: case INT: - ((IntIndexer) indexer).put(offset() + i, (int) element); + ((IntIndexer) indexer).put(i, (int) element); break; case UINT64: case LONG: - ((LongIndexer) indexer).put(offset() + i, element); + ((LongIndexer) indexer).put(i, element); break; case BFLOAT16: - ((Bfloat16Indexer) indexer).put(offset() + i, (float) element); + ((Bfloat16Indexer) indexer).put(i, (float) element); break; case HALF: - ((HalfIndexer) indexer).put(offset() + i, (float) element); + ((HalfIndexer) indexer).put(i, (float) element); break; case FLOAT: - ((FloatIndexer) indexer).put(offset() + i, (float) element); + ((FloatIndexer) indexer).put(i, (float) element); break; case DOUBLE: - ((DoubleIndexer) indexer).put(offset() + i, (double) element); + ((DoubleIndexer) indexer).put(i, (double) element); break; default: throw new UnsupportedOperationException("Unsupported data type: " + dataType()); } - - if (i == length) { - length++; - } } @Override @@ -2507,31 +1812,6 @@ public abstract class BaseDataBuffer implements DataBuffer { return originalOffset; } - /** - * Returns tracking point for Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @return - */ - @Override - public Long getTrackingPoint() { - if (underlyingDataBuffer() != this) - return underlyingDataBuffer() == null ? trackingPoint : underlyingDataBuffer().getTrackingPoint(); - return trackingPoint; - } - - /** - * Sets tracking point used by Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @param trackingPoint - */ - public void setTrackingPoint(Long trackingPoint) { - this.trackingPoint = trackingPoint; - } - /** * This method returns whether this DataBuffer is constant, or not. * Constant buffer means that it modified only during creation time, and then it stays the same for all lifecycle. I.e. used in shape info databuffers. @@ -2595,63 +1875,7 @@ public abstract class BaseDataBuffer implements DataBuffer { return null; } - /** - * Reallocate the native memory of the buffer - * @param length the new length of the buffer - * @return this databuffer - * */ - @Override - public DataBuffer reallocate(long length) { - - Pointer oldPointer = pointer; - if (isAttached()) { - long capacity = length * getElementSize(); - switch (dataType()) { - case DOUBLE: - pointer = getParentWorkspace().alloc(capacity, DataType.DOUBLE, false).asDoublePointer(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - pointer = getParentWorkspace().alloc(capacity, DataType.FLOAT, false).asFloatPointer(); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case INT: - pointer = getParentWorkspace().alloc(capacity, DataType.INT, false).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case LONG: - pointer = getParentWorkspace().alloc(capacity, DataType.LONG, false).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - } - - workspaceGenerationId = getParentWorkspace().getGenerationId(); - } else { - switch (dataType()) { - case INT: - pointer = new IntPointer(length); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case DOUBLE: - pointer = new DoublePointer(length); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - pointer = new FloatPointer(length); - indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case LONG: - pointer = new LongPointer(length); - indexer = LongIndexer.create((LongPointer) pointer); - break; - } - } - - Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize()); - this.underlyingLength = length; - this.length = length; - return this; - } + public abstract DataBuffer reallocate(long length); /** * @return the capacity of the buffer diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java index 9b1c2ecec..303f6383d 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java @@ -622,23 +622,6 @@ public interface DataBuffer extends Serializable, AutoCloseable { */ void read(InputStream is, AllocationMode allocationMode, long length, DataType dataType); - /** - * Returns tracking point for Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * @return - */ - Long getTrackingPoint(); - - /** - * Sets tracking point used by Allocator - * - * PLEASE NOTE: Suitable & meaningful only for specific backends - * - * @param trackingPoint - */ - void setTrackingPoint(Long trackingPoint); - /** * This method returns whether this DataBuffer is constant, or not. * Constant buffer means that it modified only during creation time, and then it stays the same for all lifecycle. I.e. used in shape info databuffers. diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java index 8e82184c4..7555bce21 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java @@ -17,14 +17,28 @@ package org.nd4j.linalg.api.buffer; public enum DataType { + DOUBLE, FLOAT, + + @Deprecated HALF, + + @Deprecated LONG, + + @Deprecated INT, + + @Deprecated SHORT, + + @Deprecated UBYTE, + + @Deprecated BYTE, + BOOL, UTF8, COMPRESSED, @@ -34,6 +48,13 @@ public enum DataType { UINT64, UNKNOWN; + public static final DataType FLOAT16 = DataType.HALF; + public static final DataType INT32 = DataType.INT; + public static final DataType INT64 = DataType.LONG; + public static final DataType INT16 = DataType.SHORT; + public static final DataType INT8 = DataType.BYTE; + public static final DataType UINT8 = DataType.UBYTE; + public static DataType fromInt(int type) { switch (type) { diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java index 743f34655..abb674499 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; +import java.nio.Buffer; import java.nio.ByteBuffer; /** @@ -60,30 +61,13 @@ public interface DataBufferFactory { DataBuffer create(DataBuffer underlyingBuffer, long offset, long length); /** - * Create int buffer - * @param buffer + * Creates a DataBuffer from java.nio.ByteBuffer + * @param underlyingBuffer + * @param offset * @param length * @return */ - DataBuffer createInt(long offset, ByteBuffer buffer, int length); - - /** - * Create a float data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createFloat(long offset, ByteBuffer buffer, int length); - - /** - * Creates a double data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createDouble(long offset, ByteBuffer buffer, int length); - - DataBuffer createLong(ByteBuffer buffer, int length); + DataBuffer create(ByteBuffer underlyingBuffer, DataType type, long length, long offset); /** * Create a double data buffer @@ -289,31 +273,6 @@ public interface DataBufferFactory { */ DataBuffer createInt(long offset, float[] data, boolean copy); - - /** - * Create int buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createInt(ByteBuffer buffer, int length); - - /** - * Create a float data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createFloat(ByteBuffer buffer, int length); - - /** - * Creates a double data buffer - * @param buffer - * @param length - * @return - */ - DataBuffer createDouble(ByteBuffer buffer, int length); - /** * Create a double data buffer * @@ -459,22 +418,6 @@ public interface DataBufferFactory { DataBuffer createDouble(double[] data); - /** - * Create a double buffer - * @param data - * @param length - * @return - */ - DataBuffer createDouble(byte[] data, int length); - - /** - * Create a double buffer - * @param data - * @param length - * @return - */ - DataBuffer createFloat(byte[] data, int length); - /** * Creates a float data buffer * @@ -816,14 +759,6 @@ public interface DataBufferFactory { */ DataBuffer createHalf(int[] data); - /** - * Creates a half-precision data buffer - * - * @param data the data to create the buffer from - * @return the new buffer - */ - DataBuffer createHalf(long offset, byte[] data, int length); - /** * Creates a half-precision data buffer * @@ -831,22 +766,6 @@ public interface DataBufferFactory { */ DataBuffer createHalf(long offset, int length); - /** - * Creates a half-precision data buffer - * - * @return the new buffer - */ - DataBuffer createHalf(ByteBuffer buffer, int length); - - /** - * Creates a half-precision data buffer - * - * @param data - * @param length - * @return - */ - DataBuffer createHalf(byte[] data, int length); - Class intBufferClass(); @@ -858,4 +777,5 @@ public interface DataBufferFactory { Class doubleBufferClass(); + DataBuffer createUtf8Buffer(byte[] data, long product); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java index 74d4bbca2..35d36ec64 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java @@ -63,7 +63,7 @@ public class PagedPointer extends Pointer { public PagedPointer(Pointer pointer, long capacity) { this.originalPointer = pointer; - this.address = pointer.address(); + this.address = pointer == null ? 0 : pointer.address(); this.capacity = capacity; this.limit = capacity; diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml index 2d0cd6afc..7b96eb072 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml @@ -60,13 +60,35 @@ testresources - + + + + nd4j-testresources + + + + nd4j-tests-cpu + + false + org.nd4j nd4j-native ${project.version} - test + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} diff --git a/nd4j/nd4j-jdbc/pom.xml b/nd4j/nd4j-jdbc/pom.xml index 05ef09942..a382cf0e8 100644 --- a/nd4j/nd4j-jdbc/pom.xml +++ b/nd4j/nd4j-jdbc/pom.xml @@ -53,6 +53,18 @@ testresources + + + nd4j-testresources + + + + nd4j-tests-cpu + + + + nd4j-tests-cuda + diff --git a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala index 826264b8f..34136f61a 100644 --- a/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala +++ b/nd4s/src/main/scala/org/nd4s/ops/FunctionalOpExecutioner.scala @@ -18,7 +18,7 @@ package org.nd4s.ops import java.util.{ List, Map, Properties } import org.bytedeco.javacpp.Pointer -import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType, Utf8Buffer } +import org.nd4j.linalg.api.buffer.{ DataBuffer, DataType } import org.nd4j.linalg.api.environment.Nd4jEnvironment import org.nd4j.linalg.api.ndarray.{ INDArray, INDArrayStatistics } import org.nd4j.linalg.api.ops.aggregates.{ Aggregate, Batch } @@ -452,7 +452,7 @@ class FunctionalOpExecutioner extends OpExecutioner { * @param index * @return */ - def getString(buffer: Utf8Buffer, index: Long): String = ??? + def getString(buffer: DataBuffer, index: Long): String = ??? /** * This method returns OpContext which can be used (and reused) to execute custom ops diff --git a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala index 25e8f374f..d1d760286 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala @@ -122,15 +122,15 @@ class ConstructionTest extends FlatSpec with Matchers { val learning_rate = 0.1 val seed = 7 - val target = Nd4j.createUninitialized(1000) + val target = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val rng = Nd4j.getRandom rng.setSeed(seed) val x1_label1 = Nd4j.randn(3.0, 1.0, target, rng) - val target1 = Nd4j.createUninitialized(1000) + val target1 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x2_label1 = Nd4j.randn(2.0, 1.0, target1, rng) - val target2 = Nd4j.createUninitialized(1000) + val target2 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x1_label2 = Nd4j.randn(7.0, 1.0, target2, rng) - val target3 = Nd4j.createUninitialized(1000) + val target3 = Nd4j.createUninitialized(DataType.DOUBLE, 1000) val x2_label2 = Nd4j.randn(6.0, 1.0, target3, rng) // np.append, was not able to guess proper method