Playing with some new code 2 - clean build/test

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-17 09:41:12 +02:00
parent 9d4939ccfd
commit 82e65bdf59
15 changed files with 473 additions and 359 deletions

View File

@ -67,7 +67,7 @@ public interface WorkspaceMgr<T extends Enum<T>> {
/** /**
* Set arrays to be scoped out (not in any workspace) for the specified array type. * Set arrays to be scoped out (not in any workspace) for the specified array type.
* This means that create, dup, leverage etc methods will return result arrays that are not attached to any workspace * This means that create, dup, leverage etc. methods will return result arrays that are not attached to any workspace
* *
* @param arrayType Array type to set scoped out for * @param arrayType Array type to set scoped out for
*/ */
@ -120,7 +120,7 @@ public interface WorkspaceMgr<T extends Enum<T>> {
boolean isWorkspaceOpen(T arrayType); boolean isWorkspaceOpen(T arrayType);
/** /**
* Assert thath the workspace for the specified array type is open. * Assert that the workspace for the specified array type is open.
* For array types that are set to scoped out, this will be treated as a no-op * For array types that are set to scoped out, this will be treated as a no-op
* @param arrayType Array type to check * @param arrayType Array type to check
* @param msg May be null. If non-null: include this in the exception * @param msg May be null. If non-null: include this in the exception
@ -129,7 +129,7 @@ public interface WorkspaceMgr<T extends Enum<T>> {
void assertOpen(T arrayType, String msg) throws ND4JWorkspaceException; void assertOpen(T arrayType, String msg) throws ND4JWorkspaceException;
/** /**
* Assert thath the workspace for the specified array type is not open. * Assert that the workspace for the specified array type is not open.
* For array types that are set to scoped out, this will be treated as a no-op * For array types that are set to scoped out, this will be treated as a no-op
* @param arrayType Array type to check * @param arrayType Array type to check
* @param msg May be null. If non-null: include this in the exception * @param msg May be null. If non-null: include this in the exception
@ -193,7 +193,7 @@ public interface WorkspaceMgr<T extends Enum<T>> {
/** /**
* Create an uninitialized array in the specified array type's workspace (or detached if none is specified). * Create an uninitialized array in the specified array type's workspace (or detached if none is specified).
* Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int)} (int...)}, other than the array location * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int...)}, other than the array location
* @param arrayType Array type * @param arrayType Array type
* @param dataType Data type of the created array * @param dataType Data type of the created array
* @param shape Shape * @param shape Shape
@ -231,7 +231,7 @@ public interface WorkspaceMgr<T extends Enum<T>> {
/** /**
* Cast the specified array to the specified datatype.<br> * Cast the specified array to the specified datatype.<br>
* If the array is already the correct type, the bahaviour depends on the 'dupIfCorrectType' argument.<br> * If the array is already the correct type, the behaviour depends on the 'dupIfCorrectType' argument.<br>
* dupIfCorrectType = false && toCast.dataType() == dataType: return input array as-is (unless workspace is wrong)<br> * dupIfCorrectType = false && toCast.dataType() == dataType: return input array as-is (unless workspace is wrong)<br>
* dupIfCorrectType = true && toCast.dataType() == dataType: duplicate the array into the specified workspace<br> * dupIfCorrectType = true && toCast.dataType() == dataType: duplicate the array into the specified workspace<br>
* @param arrayType Array type * @param arrayType Array type

View File

@ -81,7 +81,7 @@ public class EvaluationToolsTests extends BaseDL4JTest {
String str = EvaluationTools.rocChartToHtml(roc); String str = EvaluationTools.rocChartToHtml(roc);
// System.out.println(str); System.out.println(str);
} }
} }

View File

@ -58,6 +58,8 @@ public class Cnn3DLossLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
setNetConfiguration(conf); setNetConfiguration(conf);
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret = org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret =
new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(lconf, networkDataType);
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);

View File

@ -63,6 +63,8 @@ public class CnnLossLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
setNetConfiguration(conf); setNetConfiguration(conf);
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret = org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret =
new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(lconf, networkDataType);
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);

View File

@ -77,7 +77,11 @@ public class GravesLSTM extends AbstractLSTM {
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerValidation.assertNInNOutSet("GravesLSTM", getLayerName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("GravesLSTM", getLayerName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.setNetConfiguration(conf);
runInheritance();
org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret =
new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(lconf, networkDataType); new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(lconf, networkDataType);

View File

@ -61,6 +61,8 @@ public class RnnLossLayer extends FeedForwardLayer {
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
lconf.setNetConfiguration(conf);
runInheritance();
org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret = org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret =
new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(lconf, networkDataType);

View File

@ -135,6 +135,7 @@ public class SubsamplingLayer extends NoParamLayer {
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) { boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret = org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret =
new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(lconf, networkDataType);

View File

@ -24,6 +24,8 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
import lombok.*; import lombok.*;
import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.api.ITraininableLayerConfiguration;
@ -328,13 +330,9 @@ public abstract class AbstractLayer<LayerConf_T extends LayerConfiguration> impl
@Override @Override
public void clearNoiseWeightParams() {} public void clearNoiseWeightParams() {}
public List<String> variables() { public List<String> getVariables(boolean copy) {
return variables;
}
public List<String> variables(boolean copy) {
if (copy) { if (copy) {
return variables(); return new ArrayList<>(getVariables());
} }
return variables; return variables;
} }

View File

@ -184,6 +184,17 @@ public abstract class BaseLayer<LayerConfT extends BaseLayerConfiguration>
setParams(params, 'f'); setParams(params, 'f');
} }
/**
* The AbstractLayer does not implement Params, ParamTable and GradientView; there, a
* RuntimeException will be triggered when calling this. This override instead returns
* the layer's {@code paramsFlattened} field.
*
* @return 1d parameter vector
*/
@Override
public INDArray getParams() {
return paramsFlattened;
}
/** */ /** */
@Override @Override
public void close() {} public void close() {}
@ -358,7 +369,7 @@ public abstract class BaseLayer<LayerConfT extends BaseLayerConfiguration>
protected void setParams(INDArray params, char order) { protected void setParams(INDArray params, char order) {
if (params == null) { if (params == null) {
log.warn( log.trace(
"setParams(INDArray params, char order): params is null. Skipping setParams in Layer {}[{}] at index {}", "setParams(INDArray params, char order): params is null. Skipping setParams in Layer {}[{}] at index {}",
getLayerConfiguration().getLayerName(), getLayerConfiguration().getLayerName(),
getClass().getSimpleName(), getClass().getSimpleName(),

View File

@ -110,14 +110,14 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams));
params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams));
layerConf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); layerConf.addVariable(WEIGHT_KEY);
long offset = nWeightParams; long offset = nWeightParams;
if(hasBias(layerConf)){ if(hasBias(layerConf)){
INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true),
NDArrayIndex.interval(offset, offset + nOut)); NDArrayIndex.interval(offset, offset + nOut));
params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams)); params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams));
layerConf.getNetConfiguration().addNetWideVariable(BIAS_KEY); layerConf.addVariable(BIAS_KEY);
offset += nOut; offset += nOut;
} }
@ -125,7 +125,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer {
INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true), INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true),
NDArrayIndex.interval(offset, offset + nOut)); NDArrayIndex.interval(offset, offset + nOut));
params.put(GAIN_KEY, createGain(conf, gainView, initializeParams)); params.put(GAIN_KEY, createGain(conf, gainView, initializeParams));
conf.getNetConfiguration().addNetWideVariable(GAIN_KEY); conf.addVariable(GAIN_KEY);
} }
return params; return params;

View File

@ -50,7 +50,8 @@ public class CpuOpContext extends BaseOpContext implements OpContext, Deallocata
@Override @Override
public void close() { public void close() {
// no-op nativeOps.ctxPurge(context);
context.deallocate();
} }
@Override @Override

View File

@ -20,6 +20,8 @@
package org.nd4j.jita.workspace; package org.nd4j.jita.workspace;
import java.util.List;
import java.util.Queue;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
@ -39,10 +41,6 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.List;
import java.util.Queue;
/** /**
* CUDA-aware MemoryWorkspace implementation * CUDA-aware MemoryWorkspace implementation
* *
@ -51,7 +49,6 @@ import java.util.Queue;
@Slf4j @Slf4j
public class CudaWorkspace extends Nd4jWorkspace { public class CudaWorkspace extends Nd4jWorkspace {
public CudaWorkspace(@NonNull WorkspaceConfiguration configuration) { public CudaWorkspace(@NonNull WorkspaceConfiguration configuration) {
super(configuration); super(configuration);
} }
@ -60,7 +57,10 @@ public class CudaWorkspace extends Nd4jWorkspace {
super(configuration, workspaceId); super(configuration, workspaceId);
} }
public CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId, Integer deviceId) { public CudaWorkspace(
@NonNull WorkspaceConfiguration configuration,
@NonNull String workspaceId,
Integer deviceId) {
super(configuration, workspaceId); super(configuration, workspaceId);
this.deviceId = deviceId; this.deviceId = deviceId;
} }
@ -74,29 +74,39 @@ public class CudaWorkspace extends Nd4jWorkspace {
super.init(); super.init();
if (currentSize.get() > 0) { if (currentSize.get() > 0) {
//log.info("Allocating {} bytes at DEVICE & HOST space...", currentSize.get()); log.debug("Allocating {} bytes at DEVICE & HOST space...", currentSize.get());
isInit.set(true); isInit.set(true);
long bytes = currentSize.get(); long bytes = currentSize.get();
if (isDebug.get()) log.debug(
log.info("Allocating [{}] workspace on device_{}, {} bytes...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes); "Allocating [{}] workspace on device_{}, {} bytes...",
id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
bytes);
if (isDebug.get()) { if (isDebug.get()) {
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread(); Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
} }
Pointer ptr = memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.HOST, false); Pointer ptr = memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.HOST, false);
if (ptr == null) if (ptr == null) throw new ND4JIllegalStateException("Can't allocate memory for workspace");
throw new ND4JIllegalStateException("Can't allocate memory for workspace");
workspace.setHostPointer(new PagedPointer(ptr)); workspace.setHostPointer(new PagedPointer(ptr));
if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) { if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) {
workspace.setDevicePointer(new PagedPointer(memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.DEVICE, false))); workspace.setDevicePointer(
AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); new PagedPointer(
memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.DEVICE, false)));
AllocationsTracker.getInstance()
.markAllocated(
AllocationKind.GENERAL,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
bytes + SAFETY_OFFSET);
MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); MemoryTracker.getInstance()
.incrementWorkspaceAllocatedAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET);
// if base pointer isn't aligned to 16 bytes (128 bits) - adjust the offset then
val addr = workspace.getDevicePointer().address(); val addr = workspace.getDevicePointer().address();
@ -114,14 +124,12 @@ public class CudaWorkspace extends Nd4jWorkspace {
return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize); return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize);
} }
@Override @Override
public synchronized void destroyWorkspace(boolean extended) { public synchronized void destroyWorkspace(boolean extended) {
val size = currentSize.getAndSet(0); val size = currentSize.getAndSet(0);
reset(); reset();
if (extended) if (extended) clearExternalAllocations();
clearExternalAllocations();
clearPinnedAllocations(extended); clearPinnedAllocations(extended);
@ -129,20 +137,27 @@ public class CudaWorkspace extends Nd4jWorkspace {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer());
if (workspace.getDevicePointer() != null) { if (workspace.getDevicePointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(workspace.getDevicePointer(), 0); NativeOpsHolder.getInstance()
AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); .getDeviceNativeOps()
.freeDevice(workspace.getDevicePointer(), 0);
AllocationsTracker.getInstance()
.markReleased(
AllocationKind.GENERAL,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
size + SAFETY_OFFSET);
MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); MemoryTracker.getInstance()
.decrementWorkspaceAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET);
} }
workspace.setDevicePointer(null); workspace.setDevicePointer(null);
workspace.setHostPointer(null); workspace.setHostPointer(null);
} }
@Override @Override
public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType type, boolean initialize) { public PagedPointer alloc(
long requiredMemory, MemoryKind kind, DataType type, boolean initialize) {
long numElements = requiredMemory / Nd4j.sizeOfDataType(type); long numElements = requiredMemory / Nd4j.sizeOfDataType(type);
// alignment // alignment
@ -151,48 +166,74 @@ public class CudaWorkspace extends Nd4jWorkspace {
if (!isUsed.get()) { if (!isUsed.get()) {
if (disabledCounter.incrementAndGet() % 10 == 0) if (disabledCounter.incrementAndGet() % 10 == 0)
log.warn("Worskpace was turned off, and wasn't enabled after {} allocations", disabledCounter.get()); log.warn(
"Workspace was turned off, and wasn't enabled after {} allocations",
disabledCounter.get());
if (kind == MemoryKind.DEVICE) { if (kind == MemoryKind.DEVICE) {
val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); val pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
externalAllocations.add(new PointersPair(null, pointer)); externalAllocations.add(new PointersPair(null, pointer));
MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); MemoryTracker.getInstance()
.incrementWorkspaceAllocatedAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory);
return pointer; return pointer;
} else { } else {
val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); val pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
externalAllocations.add(new PointersPair(pointer, null)); externalAllocations.add(new PointersPair(pointer, null));
return pointer; return pointer;
} }
} }
boolean trimmer = (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && requiredMemory + cycleAllocations.get() > initialBlockSize.get() && initialBlockSize.get() > 0 && kind == MemoryKind.DEVICE) || trimmedMode.get(); boolean trimmer =
(workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
&& requiredMemory + cycleAllocations.get() > initialBlockSize.get()
&& initialBlockSize.get() > 0
&& kind == MemoryKind.DEVICE)
|| trimmedMode.get();
if (trimmer && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !trimmedMode.get()) { if (trimmer
&& workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE
&& !trimmedMode.get()) {
trimmedMode.set(true); trimmedMode.set(true);
trimmedStep.set(stepsCount.get()); trimmedStep.set(stepsCount.get());
} }
if (kind == MemoryKind.DEVICE) { if (kind == MemoryKind.DEVICE) {
if (deviceOffset.get() + requiredMemory <= currentSize.get() && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { if (deviceOffset.get() + requiredMemory <= currentSize.get()
&& !trimmer
&& Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
cycleAllocations.addAndGet(requiredMemory); cycleAllocations.addAndGet(requiredMemory);
long prevOffset = deviceOffset.getAndAdd(requiredMemory); long prevOffset = deviceOffset.getAndAdd(requiredMemory);
if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) return null;
return null;
val ptr = workspace.getDevicePointer().withOffset(prevOffset, numElements); val ptr = workspace.getDevicePointer().withOffset(prevOffset, numElements);
if (isDebug.get()) log.debug(
log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements, prevOffset, deviceOffset.get(), currentSize.get(), ptr.address()); "Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}",
id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
requiredMemory,
numElements,
prevOffset,
deviceOffset.get(),
currentSize.get(),
ptr.address());
if (initialize) { if (initialize) {
val context = AtomicAllocator.getInstance().getDeviceContext(); val context = AtomicAllocator.getInstance().getDeviceContext();
int ret = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream()); int ret =
NativeOpsHolder.getInstance()
.getDeviceNativeOps()
.memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream());
if (ret == 0) if (ret == 0)
throw new ND4JIllegalStateException("memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread()); throw new ND4JIllegalStateException(
"memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread());
context.syncSpecialStream(); context.syncSpecialStream();
} }
@ -201,28 +242,34 @@ public class CudaWorkspace extends Nd4jWorkspace {
} else { } else {
// spill // spill
if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
//log.info("End of space reached. Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory); && currentSize.get() > 0
&& !trimmer
&& Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
// log.info("End of space reached. Current offset: {}; requiredMemory: {}",
// deviceOffset.get(), requiredMemory);
deviceOffset.set(0); deviceOffset.set(0);
resetPlanned.set(true); resetPlanned.set(true);
return alloc(requiredMemory, kind, type, initialize); return alloc(requiredMemory, kind, type, initialize);
} }
if (!trimmer) if (!trimmer) spilledAllocationsSize.addAndGet(requiredMemory);
spilledAllocationsSize.addAndGet(requiredMemory); else pinnedAllocationsSize.addAndGet(requiredMemory);
else
pinnedAllocationsSize.addAndGet(requiredMemory);
if (isDebug.get()) { log.debug(
log.info("Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements); "Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements",
} id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
requiredMemory,
numElements);
val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); val shape =
new AllocationShape(
requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type);
cycleAllocations.addAndGet(requiredMemory); cycleAllocations.addAndGet(requiredMemory);
if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) return null;
return null;
switch (workspaceConfiguration.getPolicySpill()) { switch (workspaceConfiguration.getPolicySpill()) {
case REALLOCATE: case REALLOCATE:
@ -230,74 +277,104 @@ public class CudaWorkspace extends Nd4jWorkspace {
if (!trimmer) { if (!trimmer) {
externalCount.incrementAndGet(); externalCount.incrementAndGet();
// //
//AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape,
val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); // null, AllocationStatus.DEVICE).getDevicePointer()
val pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize),
numElements);
pointer.isLeaked(); pointer.isLeaked();
val pp = new PointersPair(null, pointer); val pp = new PointersPair(null, pointer);
pp.setRequiredMemory(requiredMemory); pp.setRequiredMemory(requiredMemory);
externalAllocations.add(pp); externalAllocations.add(pp);
MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); MemoryTracker.getInstance()
.incrementWorkspaceAllocatedAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory);
return pointer; return pointer;
} else { } else {
pinnedCount.incrementAndGet(); pinnedCount.incrementAndGet();
val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); val pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize),
numElements);
pointer.isLeaked(); pointer.isLeaked();
pinnedAllocations.add(new PointersPair(stepsCount.get(), requiredMemory, null, pointer)); pinnedAllocations.add(
MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); new PointersPair(stepsCount.get(), requiredMemory, null, pointer));
MemoryTracker.getInstance()
.incrementWorkspaceAllocatedAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory);
return pointer; return pointer;
} }
case FAIL: case FAIL:
default: { default:
{
throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
} }
} }
} }
} else if (kind == MemoryKind.HOST) { } else if (kind == MemoryKind.HOST) {
if (hostOffset.get() + requiredMemory <= currentSize.get() && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { if (hostOffset.get() + requiredMemory <= currentSize.get()
&& !trimmer
&& Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
long prevOffset = hostOffset.getAndAdd(requiredMemory); long prevOffset = hostOffset.getAndAdd(requiredMemory);
val ptr = workspace.getHostPointer().withOffset(prevOffset, numElements); val ptr = workspace.getHostPointer().withOffset(prevOffset, numElements);
// && workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY // && workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY
if (initialize) if (initialize) Pointer.memset(ptr, 0, requiredMemory);
Pointer.memset(ptr, 0, requiredMemory);
return ptr; return ptr;
} else { } else {
// log.info("Spilled HOST array of {} bytes, capacity of {} elements", requiredMemory, numElements); // log.info("Spilled HOST array of {} bytes, capacity of {} elements", requiredMemory,
if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { // numElements);
//log.info("End of space reached. Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory); if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
&& currentSize.get() > 0
&& !trimmer
&& Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
// log.info("End of space reached. Current offset: {}; requiredMemory: {}",
// deviceOffset.get(), requiredMemory);
hostOffset.set(0); hostOffset.set(0);
// resetPlanned.set(true); // resetPlanned.set(true);
return alloc(requiredMemory, kind, type, initialize); return alloc(requiredMemory, kind, type, initialize);
} }
val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); val shape =
new AllocationShape(
requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type);
switch (workspaceConfiguration.getPolicySpill()) { switch (workspaceConfiguration.getPolicySpill()) {
case REALLOCATE: case REALLOCATE:
case EXTERNAL: case EXTERNAL:
if (!trimmer) { if (!trimmer) {
// memoryManager.allocate(requiredMemory, MemoryKind.HOST, true) // memoryManager.allocate(requiredMemory, MemoryKind.HOST, true)
//AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape,
PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); // null, AllocationStatus.DEVICE).getDevicePointer()
PagedPointer pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize),
numElements);
externalAllocations.add(new PointersPair(pointer, null)); externalAllocations.add(new PointersPair(pointer, null));
return pointer; return pointer;
} else { } else {
//AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape,
PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); // null, AllocationStatus.DEVICE).getDevicePointer()
PagedPointer pointer =
new PagedPointer(
memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize),
numElements);
pointer.isLeaked(); pointer.isLeaked();
pinnedAllocations.add(new PointersPair(stepsCount.get(), 0L, pointer, null)); pinnedAllocations.add(new PointersPair(stepsCount.get(), 0L, pointer, null));
return pointer; return pointer;
} }
case FAIL: case FAIL:
default: { default:
{
throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
} }
} }
@ -309,37 +386,40 @@ public class CudaWorkspace extends Nd4jWorkspace {
@Override @Override
protected void clearPinnedAllocations(boolean extended) { protected void clearPinnedAllocations(boolean extended) {
if (isDebug.get())
log.info("Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), cyclesCount.get()); log.debug(
"Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...",
id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
Thread.currentThread().getId(),
cyclesCount.get());
while (!pinnedAllocations.isEmpty()) { while (!pinnedAllocations.isEmpty()) {
val pair = pinnedAllocations.peek(); val pair = pinnedAllocations.peek();
if (pair == null) if (pair == null) throw new RuntimeException();
throw new RuntimeException();
long stepNumber = pair.getAllocationCycle(); long stepNumber = pair.getAllocationCycle();
long stepCurrent = stepsCount.get(); long stepCurrent = stepsCount.get();
if (isDebug.get()) log.debug("Allocation step: {}; Current step: {}", stepNumber, stepCurrent);
log.info("Allocation step: {}; Current step: {}", stepNumber, stepCurrent);
if (stepNumber + 2 < stepCurrent || extended) { if (stepNumber + 2 < stepCurrent || extended) {
pinnedAllocations.remove(); pinnedAllocations.remove();
if (pair.getDevicePointer() != null) { if (pair.getDevicePointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0);
MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), pair.getRequiredMemory()); MemoryTracker.getInstance()
.decrementWorkspaceAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), pair.getRequiredMemory());
pinnedCount.decrementAndGet(); pinnedCount.decrementAndGet();
if (isDebug.get()) log.debug("deleting external device allocation ");
log.info("deleting external device allocation ");
} }
if (pair.getHostPointer() != null) { if (pair.getHostPointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer());
if (isDebug.get()) log.debug("deleting external host allocation ");
log.info("deleting external host allocation ");
} }
val sizez = pair.getRequiredMemory() * -1; val sizez = pair.getRequiredMemory() * -1;
@ -352,8 +432,13 @@ public class CudaWorkspace extends Nd4jWorkspace {
@Override @Override
protected void clearExternalAllocations() { protected void clearExternalAllocations() {
if (isDebug.get())
log.info("Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), guid); log.debug(
"Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...",
id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
Thread.currentThread().getId(),
guid);
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();
@ -362,25 +447,34 @@ public class CudaWorkspace extends Nd4jWorkspace {
if (pair.getHostPointer() != null) { if (pair.getHostPointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer());
if (isDebug.get()) log.debug("deleting external host allocation... ");
log.info("deleting external host allocation... ");
} }
if (pair.getDevicePointer() != null) { if (pair.getDevicePointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0);
if (isDebug.get()) log.debug("deleting external device allocation... ");
log.info("deleting external device allocation... ");
val sizez = pair.getRequiredMemory(); val sizez = pair.getRequiredMemory();
if (sizez != null) { if (sizez != null) {
AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); AllocationsTracker.getInstance()
MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); .markReleased(
AllocationKind.GENERAL,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
sizez);
MemoryTracker.getInstance()
.decrementWorkspaceAmount(
Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez);
} }
} }
} }
} catch (Exception e) { } catch (Exception e) {
log.error("RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), guid); log.error(
"RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...",
id,
Nd4j.getAffinityManager().getDeviceForCurrentThread(),
Thread.currentThread().getId(),
guid);
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -391,9 +485,7 @@ public class CudaWorkspace extends Nd4jWorkspace {
@Override @Override
protected void resetWorkspace() { protected void resetWorkspace() {
if (currentSize.get() < 1) { if (currentSize.get() < 1) {}
}
/* /*
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)

View File

@ -48,7 +48,7 @@ public class CudaWorkspaceDeallocator implements Deallocator {
@Override @Override
public void deallocate() { public void deallocate() {
log.trace("Deallocating CUDA workspace"); log.debug("Deallocating CUDA workspace");
// purging workspace planes // purging workspace planes
if (pointersPair != null) { if (pointersPair != null) {

View File

@ -1582,7 +1582,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
} }
if (nativeOps.lastErrorCode() != 0) if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage()); throw new RuntimeException(nativeOps.lastErrorMessage() + " error code: " + nativeOps.lastErrorCode());
profilingConfigurableHookOut(op, oc, st); profilingConfigurableHookOut(op, oc, st);

View File

@ -56,7 +56,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext, Deallocat
@Override @Override
public void close() { public void close() {
// no-op nativeOps.ctxPurge(context);
context.deallocate();
} }
@Override @Override