diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 11ef2f3dc..01a9d900a 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1700,6 +1700,7 @@ public: nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor); void deleteShapeBuffer(Nd4jPointer ptr); + void deleteTadPack(Nd4jPointer ptr); const char* runLightBenchmarkSuit(bool printOut); const char* runFullBenchmarkSuit(bool printOut); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 14a2538af..871d7cdae 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2705,6 +2705,11 @@ void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) { delete buffer; } +void NativeOps::deleteTadPack(Nd4jPointer ptr) { + auto buffer = reinterpret_cast(ptr); + delete buffer; +} + nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) { return nullptr; } diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index ba271e5d8..d045ee16b 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3246,6 +3246,11 @@ void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) { delete buffer; } +void NativeOps::deleteTadPack(Nd4jPointer ptr) { + auto buffer = reinterpret_cast(ptr); + delete buffer; +} + nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) { return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index a93776528..a99ef143d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1122,6 +1122,8 @@ public abstract class NativeOps extends Pointer { public abstract void deleteShapeBuffer(Pointer state); + public abstract void deleteTadPack(Pointer pointer); + public abstract void deleteGraphState(Pointer state); public abstract int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 5dcc798ad..3b06f2ec0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -2600,6 +2600,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength()); val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads()); + nativeOps.deleteTadPack(pack); + return new TadPack(tadShape, tadOffsets); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 364d393a9..6cc3516ed 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3048,6 +3048,7 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); public native void deleteShapeBuffer(@Cast("Nd4jPointer") Pointer ptr); + public native void deleteTadPack(@Cast("Nd4jPointer") Pointer ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 0dc1e5070..cfccac828 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -2176,6 +2176,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength()); val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads()); + loop.deleteTadPack(pack); + return new TadPack(tadShape, tadOffsets); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index ed82a05e3..0ce6fc210 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3048,6 +3048,7 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); public native void deleteShapeBuffer(@Cast("Nd4jPointer") Pointer ptr); + public native void deleteTadPack(@Cast("Nd4jPointer") Pointer ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);