OpContext handling (#214)

* nano tweaks
* OpContext tweaks
* OpContext deallocators
* get rid of few mkldnn safety checks
* databuffer setSpecial fix

Signed-off-by: raver119 <raver119@gmail.com>

parent f6b3032def
commit 5d28e6143d
@@ -147,8 +147,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
    }

    //Note: batchnorm op expects rank 1 inputs for mean/var etc, not rank 2 shape [1,x]
-   context.getInputArrays().clear();
-   context.getOutputArrays().clear();
+   context.purge();
    context.setInputArray(0, x);
    context.setInputArray(1, m);
    context.setInputArray(2, v);
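Across these MKL-DNN helpers the change is the same: the paired getInputArrays().clear() / getOutputArrays().clear() calls are replaced by a single context.purge(), which resets the Java-side fastpath lists and releases the native handles in one call. A minimal sketch of the pattern (illustrative only, not code from this PR; assumes nd4j on the classpath):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.factory.Nd4j;

    public class PurgeSketch {
        public static void main(String[] args) {
            OpContext context = Nd4j.getExecutioner().buildContext();
            INDArray x = Nd4j.createUninitialized(DataType.FLOAT, 2, 4);

            context.setInputArray(0, x);
            context.setOutputArray(0, x.like());

            // The old approach, context.getInputArrays().clear() plus
            // context.getOutputArrays().clear(), only emptied the Java-side lists.
            // purge() also clears the native fastpath (via ctxPurge below):
            context.purge();

            System.out.println(context.getInputArrays().size()); // prints 0
        }
    }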
@@ -89,8 +89,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper {

    INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
    INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
-   contextBwd.getInputArrays().clear();
-   contextBwd.getOutputArrays().clear();
+   contextBwd.purge();
    for( int i=0; i<inputsArr.length; i++ ){
        contextBwd.setInputArray(i, inputsArr[i]);
    }
@@ -100,8 +99,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
    Conv2DDerivative op = new Conv2DDerivative();
    Nd4j.exec(op, contextBwd);
-   contextBwd.getInputArrays().clear();
-   contextBwd.getOutputArrays().clear();

    Gradient g = new DefaultGradient();
    if(biasGradView != null) {
@@ -145,16 +142,14 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
    weights = weights.permute(2,3,1,0);

    INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
-   context.getInputArrays().clear();
+   context.purge();
    for( int i=0; i<inputsArr.length; i++ ){
        context.setInputArray(i, inputsArr[i]);
    }
-   context.getOutputArrays().clear();

    context.setOutputArray(0, out);
    Conv2D op = new Conv2D();
    Nd4j.exec(op, context);
-   context.getInputArrays().clear();
-   context.getOutputArrays().clear();

    return out;
}
@@ -59,7 +59,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp
        context = Nd4j.getExecutioner().buildContext();
        context.setTArguments(k, alpha, beta);
        context.setIArguments((int)n);
-   }
+   } else
+       context.purge();

    LocalResponseNormalization op = new LocalResponseNormalization();
@@ -80,7 +81,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp
        context = Nd4j.getExecutioner().buildContext();
        context.setTArguments(k, alpha, beta);
        context.setIArguments((int)n);
-   }
+   } else
+       context.purge();

    context.setInputArray(0, x);
    context.setOutputArray(0, out);
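This helper builds its OpContext once, sets the scalar arguments at build time, and on every later call only purges the arrays: purge() drops fastpath inputs/outputs but leaves I/T/B/D arguments intact, so k, alpha, beta and n never need to be re-set. A small sketch of that contract (illustrative, not from the PR):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.factory.Nd4j;

    public class PurgeKeepsArgsSketch {
        public static void main(String[] args) {
            OpContext ctx = Nd4j.getExecutioner().buildContext();
            ctx.setTArguments(2.0, 1e-4, 0.75);   // k, alpha, beta, as in the LRN helper
            ctx.setIArguments(5);                 // n
            ctx.setInputArray(0, Nd4j.createUninitialized(DataType.FLOAT, 1, 8));

            ctx.purge();  // in/out arrays are dropped, but the T/I args above survive
        }
    }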
@@ -132,13 +132,12 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
        return null;
    }

-   context.getInputArrays().clear();
-   context.getOutputArrays().clear();
+   context.purge();
    context.setInputArray(0, input);
    context.setOutputArray(0, output);

    Nd4j.exec(op, context);

    return output;
}
@@ -1601,6 +1601,7 @@ ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext*
ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
+ND4J_EXPORT void ctxPurge(OpaqueContext* ptr);
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
@@ -2815,6 +2815,10 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
    ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
    return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
}
@@ -3771,6 +3771,10 @@ void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) {
    ptr->setShapeFunctionOverride(reallyOverride);
}

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
int binaryLevel() {
    return 0;
}
@@ -305,12 +305,17 @@ namespace nd4j {
    if (_primaryBuffer != nullptr && _isOwnerPrimary) {
        deletePrimary();
    }

    _primaryBuffer = buffer;
    _isOwnerPrimary = false;
    _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}

void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
    if (_specialBuffer != nullptr && _isOwnerSpecial) {
        deleteSpecial();
    }

    this->setSpecial(buffer, false);
    _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
}
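This is the "databuffer setSpecial fix" from the commit message: instead of assigning the special pointer field directly, setSpecialBuffer now routes through setSpecial(buffer, false), so ownership bookkeeping lives in one place and the byte length is always derived from the element count and data type. A rough Java analogue of the owner-flag pattern (hypothetical class, purely to illustrate the idea):

    public class ExternalBufferHolder {
        private long address;      // stand-in for a native pointer
        private boolean owner;     // did we allocate it (and must free it)?
        private long lenInBytes;

        // Attach an externally owned buffer, mirroring DataBuffer::setSpecialBuffer.
        void attachExternal(long addr, long elements, int elementSize) {
            if (address != 0 && owner) {
                freeNative(address);     // free only buffers we own
            }
            setBuffer(addr, false);      // the single place that touches the flag
            lenInBytes = elements * elementSize;
        }

        private void setBuffer(long addr, boolean isOwner) {
            address = addr;
            owner = isOwner;
        }

        private static void freeNative(long addr) { /* native free stand-in */ }
    }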
@@ -204,6 +204,13 @@ namespace nd4j {
    void setBArguments(const std::vector<bool> &tArgs);
    void setDArguments(const std::vector<nd4j::DataType> &dArgs);

+   /**
+    * This method purges fastpath in/out contents and releases all the handles.
+    *
+    * PLEASE NOTE: I/T/B/D args will stay intact
+    */
+   void clearFastPath();
+
    void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

    void allowHelpers(bool reallyAllow);
@@ -563,6 +563,16 @@ namespace nd4j {
        for (auto d:dArgs)
            _dArgs.emplace_back(d);
    }

+   void Context::clearFastPath() {
+       _fastpath_in.clear();
+       _fastpath_out.clear();
+
+       for (auto v:_handles)
+           delete v;
+
+       _handles.clear();
+   }
}
}
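clearFastPath is what makes long-lived contexts safe: besides emptying the in/out vectors, it deletes every array wrapper tracked in _handles, so a context cached across many executions does not leak the wrappers created per call. That is exactly the usage pattern the helpers above rely on; a sketch from the Java side (illustrative only, op construction elided):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReuseLoopSketch {
        public static void main(String[] args) {
            OpContext ctx = Nd4j.getExecutioner().buildContext();

            for (int e = 0; e < 1000; e++) {
                INDArray in = Nd4j.createUninitialized(DataType.FLOAT, 1, 128);

                ctx.purge();               // releases the previous iteration's handles
                ctx.setInputArray(0, in);
                ctx.setOutputArray(0, in.like());
                // Nd4j.exec(op, ctx);     // some op would run here
            }
        }
    }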
@@ -456,10 +456,6 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   // if (::optimalLevel() < 2)
-   //     return false;
-
    auto input = INPUT_VARIABLE(0);     // 2D:nc, 4D:nchw, 5D:ncdhw
    auto mean = INPUT_VARIABLE(1);      // [c]
    auto variance = INPUT_VARIABLE(2);  // [c]
@@ -265,10 +265,6 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
}

PLATFORM_CHECK(conv2d, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   if (::optimalLevel() < 2)
-       return false;
-
    auto input = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
@@ -270,10 +270,6 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
}

PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   if (::optimalLevel() < 2)
-       return false;
-
    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
@@ -335,7 +331,6 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
}

PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
-
    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
@@ -407,10 +407,6 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
}

PLATFORM_CHECK(deconv2d, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   // if (::optimalLevel() < 2)
-   //     return false;
-
    auto input = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@@ -422,10 +422,6 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
}

PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   // if (::optimalLevel() < 2)
-   //     return false;
-
    auto input = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@@ -401,10 +401,6 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
-   // we don't want to use mkldnn if cpu doesn't support avx/avx2
-   if (::optimalLevel() < 2)
-       return false;
-
    auto input = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@@ -477,7 +473,6 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) {
-
    auto input = INPUT_VARIABLE(0);                               // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                             // [kH, kW, iC, mC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;  // [oC] = [iC*mC]
@@ -43,7 +43,7 @@ public class DeallocatorService {
    private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
    private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();

-   private AtomicLong counter = new AtomicLong(0);
+   private final transient AtomicLong counter = new AtomicLong(0);

    public DeallocatorService() {
        // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
@@ -153,4 +153,10 @@ public abstract class BaseOpContext implements OpContext {
        for (int e = 0; e < arrays.length; e++)
            setOutputArray(e, arrays[e]);
    }

+   @Override
+   public void purge() {
+       fastpath_in.clear();
+       fastpath_out.clear();
+   }
}
@@ -162,4 +162,9 @@ public interface OpContext extends AutoCloseable {
     * @param mode
     */
    void setExecutionMode(ExecutionMode mode);

+   /**
+    * This method removes all in/out arrays from this OpContext
+    */
+   void purge();
}
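Together with AutoCloseable, this gives OpContext a simple lifecycle: build, attach arrays, execute, purge to reuse, close when done. A minimal end-to-end sketch (illustrative; the op execution itself is elided):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.OpContext;
    import org.nd4j.linalg.factory.Nd4j;

    public class LifecycleSketch {
        public static void main(String[] args) throws Exception {
            try (OpContext ctx = Nd4j.getExecutioner().buildContext()) {
                INDArray x = Nd4j.createUninitialized(DataType.FLOAT, 2, 2);
                ctx.setInputArray(0, x);
                ctx.setOutputArray(0, x.like());
                // Nd4j.exec(op, ctx);  // some op would run here
                ctx.purge();            // detach arrays; the context can now be reused
            }  // with this PR, close() is a no-op: the native context is freed by the
               // GC-driven deallocator registered in the constructor (see the
               // CudaOpContext / CpuOpContext hunks below)
        }
    }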
@@ -1161,6 +1161,7 @@ public interface NativeOps {
    void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
    void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
    void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
+   void ctxPurge(OpaqueContext ptr);
    void deleteGraphContext(OpaqueContext ptr);

    OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
@@ -23,6 +23,8 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
@@ -40,14 +42,19 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator;
 * CUDA wrapper for op Context
 * @author raver119@gmail.com
 */
-public class CudaOpContext extends BaseOpContext implements OpContext {
+public class CudaOpContext extends BaseOpContext implements OpContext, Deallocatable {
    // we might want to have configurable
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private OpaqueContext context = nativeOps.createGraphContext(1);
+   private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+   public CudaOpContext() {
+       Nd4j.getDeallocatorService().pickObject(this);
+   }

    @Override
    public void close() {
-       nativeOps.deleteGraphContext(context);
+       // no-op
    }

    @Override
@@ -143,4 +150,25 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
        super.setExecutionMode(mode);
        nativeOps.ctxSetExecutionMode(context, mode.ordinal());
    }

+   @Override
+   public void purge() {
+       super.purge();
+       nativeOps.ctxPurge(context);
+   }
+
+   @Override
+   public String getUniqueId() {
+       return new String("CTX_" + id);
+   }
+
+   @Override
+   public Deallocator deallocator() {
+       return new CudaOpContextDeallocator(this);
+   }
+
+   @Override
+   public int targetDevice() {
+       return 0;
+   }
}
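This is the "OpContext deallocators" part of the commit: instead of relying on an explicit close(), the context registers itself with the DeallocatorService; once the Java object becomes unreachable, the service invokes its Deallocator, which frees the native graph context. Note that CudaOpContextDeallocator (below) captures only the native OpaqueContext pointer, not the CudaOpContext itself, otherwise the object could never become unreachable. A sketch of the pattern with a hypothetical resource class (all names here are illustrative):

    import org.nd4j.linalg.api.memory.Deallocatable;
    import org.nd4j.linalg.api.memory.Deallocator;
    import org.nd4j.linalg.factory.Nd4j;

    public class TrackedResource implements Deallocatable {
        private final transient long id = Nd4j.getDeallocatorService().nextValue();
        private final long nativeHandle = allocateNative();  // stand-in for a native allocation

        public TrackedResource() {
            Nd4j.getDeallocatorService().pickObject(this);   // start GC-driven tracking
        }

        @Override
        public String getUniqueId() {
            return "TRK_" + id;                              // must be unique per instance
        }

        @Override
        public Deallocator deallocator() {
            final long handle = nativeHandle;                // capture the handle, NOT `this`
            return () -> freeNative(handle);
        }

        @Override
        public int targetDevice() {
            return 0;
        }

        private static long allocateNative() { return 42L; } // stand-ins for native calls
        private static void freeNative(long handle) { }
    }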
@@ -0,0 +1,34 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.jcublas.ops.executioner;

import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueContext;

public class CudaOpContextDeallocator implements Deallocator {
    private transient final OpaqueContext context;

    public CudaOpContextDeallocator(CudaOpContext ctx) {
        context = (OpaqueContext) ctx.contextPointer();
    }

    @Override
    public void deallocate() {
        NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
    }
}
@@ -3090,6 +3090,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@@ -6453,6 +6454,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);

+/**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+public native void clearFastPath();
+
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);

public native void allowHelpers(@Cast("bool") boolean reallyAllow);
@@ -43,7 +43,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
    protected transient OpaqueDataBuffer ptrDataBuffer;

-   private final long instanceId = Nd4j.getDeallocatorService().nextValue();
+   private transient final long instanceId = Nd4j.getDeallocatorService().nextValue();

    protected BaseCpuDataBuffer() {

@@ -52,7 +52,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
    @Override
    public String getUniqueId() {
-       return "BCDB_" + instanceId;
+       return new String("BCDB_" + instanceId);
    }

    @Override
@@ -28,7 +28,7 @@ import org.nd4j.nativeblas.OpaqueDataBuffer;
 */
@Slf4j
public class CpuDeallocator implements Deallocator {
-   private OpaqueDataBuffer opaqueDataBuffer;
+   private final transient OpaqueDataBuffer opaqueDataBuffer;

    public CpuDeallocator(BaseCpuDataBuffer buffer) {
        opaqueDataBuffer = buffer.getOpaqueDataBuffer();
@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
+import org.nd4j.nativeblas.OpaqueDataBuffer;

import java.nio.ByteBuffer;

@@ -123,7 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer {
    // we still want this buffer to have native representation

-   ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false);
+   ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false);
    NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements);

    Nd4j.getDeallocatorService().pickObject(this);
@@ -20,11 +20,14 @@ import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.Deallocatable;
+import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.ExecutionMode;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
@@ -38,14 +41,19 @@ import java.util.List;
 *
 * @author raver119@gmail.com
 */
-public class CpuOpContext extends BaseOpContext implements OpContext {
+public class CpuOpContext extends BaseOpContext implements OpContext, Deallocatable {
    // we might want to have configurable
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private OpaqueContext context = nativeOps.createGraphContext(1);
+   private final transient long id = Nd4j.getDeallocatorService().nextValue();
+
+   public CpuOpContext() {
+       Nd4j.getDeallocatorService().pickObject(this);
+   }

    @Override
    public void close() {
-       nativeOps.deleteGraphContext(context);
+       // no-op
    }

    @Override
@@ -136,4 +144,25 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
        super.setExecutionMode(mode);
        nativeOps.ctxSetExecutionMode(context, mode.ordinal());
    }

+   @Override
+   public void purge() {
+       super.purge();
+       nativeOps.ctxPurge(context);
+   }
+
+   @Override
+   public String getUniqueId() {
+       return new String("CTX_" + id);
+   }
+
+   @Override
+   public Deallocator deallocator() {
+       return new CpuOpContextDeallocator(this);
+   }
+
+   @Override
+   public int targetDevice() {
+       return 0;
+   }
}
@@ -0,0 +1,34 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.cpu.nativecpu.ops;

import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueContext;

public class CpuOpContextDeallocator implements Deallocator {
    private transient final OpaqueContext context;

    public CpuOpContextDeallocator(CpuOpContext ctx) {
        context = (OpaqueContext) ctx.contextPointer();
    }

    @Override
    public void deallocate() {
        NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context);
    }
}
@@ -3093,6 +3093,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride);
public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
+public native void ctxPurge(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@@ -6456,6 +6457,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);

+/**
+ * This method purges fastpath in/out contents and releases all the handles.
+ *
+ * PLEASE NOTE: I/T/B/D args will stay intact
+ */
+public native void clearFastPath();
+
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);

public native void allowHelpers(@Cast("bool") boolean reallyAllow);
@@ -8262,6 +8262,31 @@ public class Nd4jTestsC extends BaseNd4jTest {
        assertArrayEquals(new long[]{10, 0}, out2.shape());
    }

+   @Test
+   public void testDealloc_1() throws Exception {
+
+       for (int e = 0; e < 5000; e++){
+           try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) {
+               val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000);
+               //val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10);
+               //val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10);
+               //val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10);
+           } finally {
+               //System.gc();
+           }
+       }
+
+       Thread.sleep(1000);
+       System.gc();
+
+       Thread.sleep(1000);
+       System.gc();
+       System.gc();
+       System.gc();
+
+       //Nd4j.getMemoryManager().printRemainingStacks();
+   }

    @Override
    public char ordering() {
        return 'c';