delete temporary TadPack C++/Java side (#74)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-22 15:55:28 +03:00 committed by AlexDBlack
parent 59a006ce29
commit b9708be5db
8 changed files with 19 additions and 0 deletions

View File

@ -1700,6 +1700,7 @@ public:
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
void deleteShapeBuffer(Nd4jPointer ptr);
void deleteTadPack(Nd4jPointer ptr);
const char* runLightBenchmarkSuit(bool printOut);
const char* runFullBenchmarkSuit(bool printOut);

View File

@ -2705,6 +2705,11 @@ void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) {
delete buffer;
}
void NativeOps::deleteTadPack(Nd4jPointer ptr) {
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
delete buffer;
}
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
return nullptr;
}

View File

@ -3246,6 +3246,11 @@ void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) {
delete buffer;
}
void NativeOps::deleteTadPack(Nd4jPointer ptr) {
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
delete buffer;
}
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
}

View File

@ -1122,6 +1122,8 @@ public abstract class NativeOps extends Pointer {
public abstract void deleteShapeBuffer(Pointer state);
public abstract void deleteTadPack(Pointer pointer);
public abstract void deleteGraphState(Pointer state);
public abstract int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);

View File

@ -2600,6 +2600,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength());
val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads());
nativeOps.deleteTadPack(pack);
return new TadPack(tadShape, tadOffsets);
}

View File

@ -3048,6 +3048,7 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native void deleteShapeBuffer(@Cast("Nd4jPointer") Pointer ptr);
public native void deleteTadPack(@Cast("Nd4jPointer") Pointer ptr);
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);

View File

@ -2176,6 +2176,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength());
val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads());
loop.deleteTadPack(pack);
return new TadPack(tadShape, tadOffsets);
}

View File

@ -3048,6 +3048,7 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native void deleteShapeBuffer(@Cast("Nd4jPointer") Pointer ptr);
public native void deleteTadPack(@Cast("Nd4jPointer") Pointer ptr);
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);