OpContext handling (#214)

* nano tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* OpContext tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* OpContext deallocators

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of a few mkldnn safety checks

Signed-off-by: raver119 <raver119@gmail.com>

* databuffer setSpecial fix

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Author: raver119 <raver119@gmail.com>, 2020-02-05 07:27:24 +03:00 (committed by GitHub)
Parent: f6b3032def
Commit: 5d28e6143d
30 changed files with 229 additions and 50 deletions


@@ -147,8 +147,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
         }

         //Note: batchnorm op expects rank 1 inputs for mean/var etc, not rank 2 shape [1,x]
-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();
+        context.purge();
         context.setInputArray(0, x);
         context.setInputArray(1, m);
         context.setInputArray(2, v);


@@ -89,8 +89,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
         INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
         INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
-        contextBwd.getInputArrays().clear();
-        contextBwd.getOutputArrays().clear();
+        contextBwd.purge();
         for( int i=0; i<inputsArr.length; i++ ){
             contextBwd.setInputArray(i, inputsArr[i]);
         }

@@ -100,8 +99,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
         Conv2DDerivative op = new Conv2DDerivative();
         Nd4j.exec(op, contextBwd);
-        contextBwd.getInputArrays().clear();
-        contextBwd.getOutputArrays().clear();

         Gradient g = new DefaultGradient();
         if(biasGradView != null) {

@@ -145,16 +142,14 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
             weights = weights.permute(2,3,1,0);

         INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
-        context.getInputArrays().clear();
+        context.purge();
         for( int i=0; i<inputsArr.length; i++ ){
             context.setInputArray(i, inputsArr[i]);
         }
-        context.getOutputArrays().clear();
         context.setOutputArray(0, out);

         Conv2D op = new Conv2D();
         Nd4j.exec(op, context);
-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();

         return out;
     }


@@ -59,7 +59,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper implements LocalResponseNormalizationHelper {
             context = Nd4j.getExecutioner().buildContext();
             context.setTArguments(k, alpha, beta);
             context.setIArguments((int)n);
-        }
+        } else
+            context.purge();

         LocalResponseNormalization op = new LocalResponseNormalization();

@@ -80,7 +81,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper implements LocalResponseNormalizationHelper {
             context = Nd4j.getExecutioner().buildContext();
             context.setTArguments(k, alpha, beta);
             context.setIArguments((int)n);
-        }
+        } else
+            context.purge();

         context.setInputArray(0, x);
         context.setOutputArray(0, out);


@@ -132,13 +132,12 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
                 return null;
         }

-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();
+        context.purge();
         context.setInputArray(0, input);
         context.setOutputArray(0, output);

         Nd4j.exec(op, context);

         return output;
     }


@@ -1601,6 +1601,7 @@ ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
 ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
 ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
 ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
+ND4J_EXPORT void ctxPurge(OpaqueContext* ptr);
 ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
 ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
 ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);


@@ -2815,6 +2815,10 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
     ptr->setExecutionMode((samediff::ExecutionMode) execMode);
 }

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
 nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
     return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
 }


@@ -3771,6 +3771,10 @@ void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) {
     ptr->setShapeFunctionOverride(reallyOverride);
 }

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
 int binaryLevel() {
     return 0;
 }


@@ -305,12 +305,17 @@ namespace nd4j {
         if (_primaryBuffer != nullptr && _isOwnerPrimary) {
             deletePrimary();
         }

         _primaryBuffer = buffer;
         _isOwnerPrimary = false;
         _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
     }

     void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
+        if (_specialBuffer != nullptr && _isOwnerSpecial) {
+            deleteSpecial();
+        }
+
         this->setSpecial(buffer, false);
         _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
     }


@@ -204,6 +204,13 @@ namespace nd4j {
         void setBArguments(const std::vector<bool> &tArgs);
         void setDArguments(const std::vector<nd4j::DataType> &dArgs);

+        /**
+         * This method purges fastpath in/out contents and releases all the handles.
+         *
+         * PLEASE NOTE: I/T/B/D args will stay intact
+         */
+        void clearFastPath();
+
         void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);
         void allowHelpers(bool reallyAllow);


@@ -563,6 +563,16 @@ namespace nd4j {
            for (auto d:dArgs)
                _dArgs.emplace_back(d);
        }

+        void Context::clearFastPath() {
+            _fastpath_in.clear();
+            _fastpath_out.clear();
+
+            for (auto v:_handles)
+                delete v;
+
+            _handles.clear();
+        }
    }
}


@@ -456,10 +456,6 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) {
 //////////////////////////////////////////////////////////////////////////
 PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    // if (::optimalLevel() < 2)
-    //     return false;
-
     auto input = INPUT_VARIABLE(0);    // 2D:nc, 4D:nchw, 5D:ncdhw
     auto mean = INPUT_VARIABLE(1);     // [c]
     auto variance = INPUT_VARIABLE(2); // [c]


@@ -265,10 +265,6 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
 }

 PLATFORM_CHECK(conv2d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    if (::optimalLevel() < 2)
-        return false;
-
     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);


@@ -270,10 +270,6 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
 }

 PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    if (::optimalLevel() < 2)
-        return false;
-
     auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

@@ -335,7 +331,6 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
 }

 PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
-
     auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]


@@ -407,10 +407,6 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
 }

 PLATFORM_CHECK(deconv2d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    // if (::optimalLevel() < 2)
-    //     return false;
-
     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;


@@ -422,10 +422,6 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
 }

 PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    // if (::optimalLevel() < 2)
-    //     return false;
-
     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;


@@ -401,10 +401,6 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
 //////////////////////////////////////////////////////////////////////
 PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    if (::optimalLevel() < 2)
-        return false;
-
     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;

@@ -477,7 +473,6 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
 //////////////////////////////////////////////////////////////////////
 PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) {
-
     auto input = INPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
     auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]


@@ -43,7 +43,7 @@ public class DeallocatorService {
     private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
     private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();

-    private AtomicLong counter = new AtomicLong(0);
+    private final transient AtomicLong counter = new AtomicLong(0);

     public DeallocatorService() {
         // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity

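For background, the fields above hint at the tracking mechanism: each Deallocatable is held through a weak reference registered with a per-device ReferenceQueue, and background threads invoke the deallocator once the wrapped object has been garbage-collected. A generic sketch of that JDK idiom, illustrative only and not the actual DeallocatorService internals:

import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;

// Generic weak-reference cleanup idiom: a daemon thread polls the queue
// and runs a release hook once the tracked object has been collected.
public class CleanupSketch {
    interface Releasable { void release(); }

    static class TrackedRef extends WeakReference<Object> {
        final Releasable hook; // must NOT reference the tracked object itself

        TrackedRef(Object target, Releasable hook, ReferenceQueue<Object> q) {
            super(target, q);
            this.hook = hook;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ReferenceQueue<Object> queue = new ReferenceQueue<>();

        Object tracked = new Object();
        // keep the TrackedRef itself strongly reachable until it is enqueued
        TrackedRef ref = new TrackedRef(tracked, () -> System.out.println("released"), queue);

        Thread poller = new Thread(() -> {
            try {
                Reference<?> collected = queue.remove(); // blocks until GC enqueues it
                ((TrackedRef) collected).hook.release();
            } catch (InterruptedException ignored) { }
        });
        poller.setDaemon(true);
        poller.start();

        tracked = null; // drop the only strong reference to the tracked object
        System.gc();
        poller.join(5000); // "released" should print once GC has run
    }
}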

@@ -153,4 +153,10 @@ public abstract class BaseOpContext implements OpContext {
         for (int e = 0; e < arrays.length; e++)
             setOutputArray(e, arrays[e]);
     }
+
+    @Override
+    public void purge() {
+        fastpath_in.clear();
+        fastpath_out.clear();
+    }
 }


@@ -162,4 +162,9 @@ public interface OpContext extends AutoCloseable {
      * @param mode
      */
     void setExecutionMode(ExecutionMode mode);
+
+    /**
+     * This method removes all in/out arrays from this OpContext
+     */
+    void purge();
 }
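
Taken together with the MKL-DNN helper hunks above, the new contract is: one purge() call replaces the paired getInputArrays().clear() / getOutputArrays().clear() calls and, on the native side, also releases fast-path handles, while scalar I/T/B/D arguments stay bound. A minimal standalone sketch of those semantics (not part of the commit; assumes an nd4j backend on the classpath):

import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j;

public class PurgeSketch {
    public static void main(String[] args) {
        // Long-lived context, as the MKL-DNN helpers keep between forward passes
        OpContext ctx = Nd4j.getExecutioner().buildContext();

        ctx.setTArguments(1e-5);              // scalar args, e.g. an epsilon
        ctx.setInputArray(0, Nd4j.ones(2, 2));
        ctx.setOutputArray(0, Nd4j.create(2, 2));

        // ... Nd4j.exec(op, ctx) would run here ...

        // One call replaces the old getInputArrays().clear() / getOutputArrays().clear()
        // pair and (via ctxPurge -> clearFastPath) also releases native handles.
        ctx.purge();

        System.out.println(ctx.getInputArrays().isEmpty());   // true
        System.out.println(ctx.getOutputArrays().isEmpty());  // true
        // Per the clearFastPath() javadoc: I/T/B/D args stay intact.
    }
}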


@@ -1161,6 +1161,7 @@ public interface NativeOps {
     void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
     void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
     void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
+    void ctxPurge(OpaqueContext ptr);
     void deleteGraphContext(OpaqueContext ptr);

     OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);


@@ -23,6 +23,8 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
 import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.concurrency.AffinityManager;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseOpContext;
 import org.nd4j.linalg.api.ops.ExecutionMode;

@@ -40,14 +42,19 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator;
  * CUDA wrapper for op Context
  * @author raver119@gmail.com
  */
-public class CudaOpContext extends BaseOpContext implements OpContext {
+public class CudaOpContext extends BaseOpContext implements OpContext, Deallocatable {
     // we might want to have configurable
     private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
     private OpaqueContext context = nativeOps.createGraphContext(1);
+    private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+    public CudaOpContext() {
+        Nd4j.getDeallocatorService().pickObject(this);
+    }

     @Override
     public void close() {
-        nativeOps.deleteGraphContext(context);
+        // no-op
     }

@@ -143,4 +150,25 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
         super.setExecutionMode(mode);
         nativeOps.ctxSetExecutionMode(context, mode.ordinal());
     }
+
+    @Override
+    public void purge() {
+        super.purge();
+        nativeOps.ctxPurge(context);
+    }
+
+    @Override
+    public String getUniqueId() {
+        return new String("CTX_" + id);
+    }
+
+    @Override
+    public Deallocator deallocator() {
+        return new CudaOpContextDeallocator(this);
+    }
+
+    @Override
+    public int targetDevice() {
+        return 0;
+    }
 }


@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.jcublas.ops.executioner;
+
+import org.nd4j.linalg.api.memory.Deallocator;
+import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueContext;
+
+public class CudaOpContextDeallocator implements Deallocator {
+    private transient final OpaqueContext context;
+
+    public CudaOpContextDeallocator(CudaOpContext ctx) {
+        context = (OpaqueContext) ctx.contextPointer();
+    }
+
+    @Override
+    public void deallocate() {
+        NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
+    }
+}
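
Note that this deallocator, like its CPU twin further down, holds only the native OpaqueContext pointer and never the CudaOpContext itself; a strong back-reference would keep the context reachable forever. A sketch of the same Deallocatable contract for a hypothetical wrapper (NativeHandle and freeNative are illustrative names, not part of this commit):

import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.factory.Nd4j;

// Illustrative wrapper around a native resource, following the same
// lifecycle as CudaOpContext/CpuOpContext in this commit.
public class NativeHandle implements Deallocatable {
    private final transient long id = Nd4j.getDeallocatorService().nextValue();
    private final long nativePointer; // hypothetical native address

    public NativeHandle(long nativePointer) {
        this.nativePointer = nativePointer;
        // register for GC-driven cleanup, as the OpContext constructors now do
        Nd4j.getDeallocatorService().pickObject(this);
    }

    @Override
    public String getUniqueId() {
        return "HANDLE_" + id; // must be unique per tracked instance
    }

    @Override
    public Deallocator deallocator() {
        // capture only the raw pointer: a strong reference back to this
        // wrapper would keep it reachable and cleanup would never fire
        final long ptr = nativePointer;
        return () -> freeNative(ptr);
    }

    @Override
    public int targetDevice() {
        return 0; // single-device case, matching the contexts above
    }

    private static void freeNative(long ptr) {
        // hypothetical native free; the real deallocators call deleteGraphContext()
    }
}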


@@ -3090,6 +3090,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
 public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
 public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
 public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
 public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
 public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
 public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);

@@ -6453,6 +6454,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
 public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
+
+/**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+public native void clearFastPath();
 public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
 public native void allowHelpers(@Cast("bool") boolean reallyAllow);


@@ -43,7 +43,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallocatable {
     protected transient OpaqueDataBuffer ptrDataBuffer;

-    private final long instanceId = Nd4j.getDeallocatorService().nextValue();
+    private transient final long instanceId = Nd4j.getDeallocatorService().nextValue();

     protected BaseCpuDataBuffer() {

@@ -52,7 +52,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallocatable {
     @Override
     public String getUniqueId() {
-        return "BCDB_" + instanceId;
+        return new String("BCDB_" + instanceId);
     }

     @Override


@@ -28,7 +28,7 @@ import org.nd4j.nativeblas.OpaqueDataBuffer;
  */
 @Slf4j
 public class CpuDeallocator implements Deallocator {
-    private OpaqueDataBuffer opaqueDataBuffer;
+    private final transient OpaqueDataBuffer opaqueDataBuffer;

     public CpuDeallocator(BaseCpuDataBuffer buffer) {
         opaqueDataBuffer = buffer.getOpaqueDataBuffer();


@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.memory.pointers.PagedPointer;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueDataBuffer;

 import java.nio.ByteBuffer;

@@ -123,7 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer {
         // we still want this buffer to have native representation
-        ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false);
+        ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false);
         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements);

         Nd4j.getDeallocatorService().pickObject(this);


@@ -20,11 +20,14 @@ import lombok.NonNull;
 import lombok.val;
 import org.bytedeco.javacpp.*;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.BaseOpContext;
 import org.nd4j.linalg.api.ops.ExecutionMode;
 import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.nativeblas.NativeOps;
 import org.nd4j.nativeblas.NativeOpsHolder;

@@ -38,14 +41,19 @@ import java.util.List;
  *
  * @author raver119@gmail.com
  */
-public class CpuOpContext extends BaseOpContext implements OpContext {
+public class CpuOpContext extends BaseOpContext implements OpContext, Deallocatable {
     // we might want to have configurable
     private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
     private OpaqueContext context = nativeOps.createGraphContext(1);
+    private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+    public CpuOpContext() {
+        Nd4j.getDeallocatorService().pickObject(this);
+    }

     @Override
     public void close() {
-        nativeOps.deleteGraphContext(context);
+        // no-op
     }

@@ -136,4 +144,25 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
         super.setExecutionMode(mode);
         nativeOps.ctxSetExecutionMode(context, mode.ordinal());
     }
+
+    @Override
+    public void purge() {
+        super.purge();
+        nativeOps.ctxPurge(context);
+    }
+
+    @Override
+    public String getUniqueId() {
+        return new String("CTX_" + id);
+    }
+
+    @Override
+    public Deallocator deallocator() {
+        return new CpuOpContextDeallocator(this);
+    }
+
+    @Override
+    public int targetDevice() {
+        return 0;
+    }
 }


@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.cpu.nativecpu.ops;
+
+import org.nd4j.linalg.api.memory.Deallocator;
+import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueContext;
+
+public class CpuOpContextDeallocator implements Deallocator {
+    private transient final OpaqueContext context;
+
+    public CpuOpContextDeallocator(CpuOpContext ctx) {
+        context = (OpaqueContext) ctx.contextPointer();
+    }
+
+    @Override
+    public void deallocate() {
+        NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
+    }
+}


@@ -3093,6 +3093,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
 public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
 public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
 public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
 public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
 public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
 public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);

@@ -6456,6 +6457,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
 public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
 public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);
+
+/**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+public native void clearFastPath();
 public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
 public native void allowHelpers(@Cast("bool") boolean reallyAllow);


@@ -8262,6 +8262,31 @@ public class Nd4jTestsC extends BaseNd4jTest {
         assertArrayEquals(new long[]{10, 0}, out2.shape());
     }

+    @Test
+    public void testDealloc_1() throws Exception {
+
+        for (int e = 0; e < 5000; e++) {
+            try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) {
+                val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000);
+                //val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10);
+                //val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10);
+                //val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10);
+            } finally {
+                //System.gc();
+            }
+        }
+
+        Thread.sleep(1000);
+        System.gc();
+        Thread.sleep(1000);
+
+        System.gc();
+        System.gc();
+        System.gc();
+
+        //Nd4j.getMemoryManager().printRemainingStacks();
+    }
+
     @Override
     public char ordering() {
         return 'c';