Few Workspace tweaks (#409)
* Allow to destroy workspaces on demand Signed-off-by: raver119 <raver119@gmail.com> * MMAP'd workspace can't have LearningPolicy Signed-off-by: raver119 <raver119@gmail.com> * throw an exception on CUDA Signed-off-by: raver119 <raver119@gmail.com>master
parent
f5f77df846
commit
5e779574cb
|
@ -138,13 +138,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) {
|
|||
OffsetType off = 0;
|
||||
int prot = PROT_READ | PROT_WRITE;
|
||||
|
||||
// we need to convert long path (probably) to short pat (actually)
|
||||
// it's Windows API, in the middle of 2018!
|
||||
auto sz = GetShortPathName(fileName, nullptr, 0);
|
||||
|
||||
auto shortName = new TCHAR[sz];
|
||||
GetShortPathName(fileName, shortName, sz);
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4293)
|
||||
|
@ -170,8 +163,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) {
|
|||
|
||||
h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
|
||||
|
||||
delete[] shortName;
|
||||
|
||||
if (h == INVALID_HANDLE_VALUE) {
|
||||
errno = __map_mman_error(GetLastError(), EPERM);
|
||||
nd4j_printf("Error code: %i\n", (int) errno);
|
||||
|
|
|
@ -297,8 +297,12 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
|
|||
}
|
||||
|
||||
protected void init() {
|
||||
// in case of MMAP we don't want any learning applied
|
||||
if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP && workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE)
|
||||
throw new IllegalArgumentException("Workspace backed by memory-mapped file can't have LearningPolicy defined");
|
||||
|
||||
// we don't want overallocation in case of MMAP
|
||||
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
|
||||
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
|
||||
if (!isOver.get()) {
|
||||
if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
|
||||
&& workspaceConfiguration.getOverallocationLimit() > 0) {
|
||||
|
@ -309,7 +313,6 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
|
|||
|
||||
if (workspaceConfiguration.getMaxSize() > 0 && currentSize.get() > workspaceConfiguration.getMaxSize())
|
||||
currentSize.set(workspaceConfiguration.getMaxSize());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager {
|
|||
if (workspace == null || workspace instanceof DummyWorkspace)
|
||||
return;
|
||||
|
||||
//workspace.destroyWorkspace();
|
||||
workspace.destroyWorkspace(true);
|
||||
backingMap.get().remove(workspace.getId());
|
||||
}
|
||||
|
||||
|
|
|
@ -138,6 +138,13 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable {
|
|||
}
|
||||
}
|
||||
|
||||
protected long mappedFileSize() {
|
||||
if (workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP)
|
||||
return 0;
|
||||
|
||||
return tempFile.length();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void clearExternalAllocations() {
|
||||
if (isDebug.get())
|
||||
|
|
|
@ -18,11 +18,15 @@ package org.nd4j.linalg.cpu.nativecpu.workspace;
|
|||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.memory.Deallocator;
|
||||
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.MemoryKind;
|
||||
import org.nd4j.linalg.api.memory.pointers.PointersPair;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Queue;
|
||||
|
@ -37,12 +41,16 @@ public class CpuWorkspaceDeallocator implements Deallocator {
|
|||
private Queue<PointersPair> pinnedPointers;
|
||||
private List<PointersPair> externalPointers;
|
||||
private LocationPolicy location;
|
||||
private Pair<LongPointer, Long> mmapInfo;
|
||||
|
||||
public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) {
|
||||
this.pointersPair = workspace.workspace();
|
||||
this.pinnedPointers = workspace.pinnedPointers();
|
||||
this.externalPointers = workspace.externalPointers();
|
||||
this.location = workspace.getWorkspaceConfiguration().getPolicyLocation();
|
||||
|
||||
if (workspace.mappedFileSize() > 0)
|
||||
this.mmapInfo = Pair.makePair(workspace.mmap, workspace.mappedFileSize());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -50,7 +58,7 @@ public class CpuWorkspaceDeallocator implements Deallocator {
|
|||
log.trace("Deallocating CPU workspace");
|
||||
|
||||
// purging workspace planes
|
||||
if (pointersPair != null) {
|
||||
if (pointersPair != null && (pointersPair.getDevicePointer() != null || pointersPair.getHostPointer() != null)) {
|
||||
if (pointersPair.getDevicePointer() != null) {
|
||||
Nd4j.getMemoryManager().release(pointersPair.getDevicePointer(), MemoryKind.DEVICE);
|
||||
}
|
||||
|
@ -58,6 +66,8 @@ public class CpuWorkspaceDeallocator implements Deallocator {
|
|||
if (pointersPair.getHostPointer() != null) {
|
||||
if (location != LocationPolicy.MMAP)
|
||||
Nd4j.getMemoryManager().release(pointersPair.getHostPointer(), MemoryKind.HOST);
|
||||
else
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().munmapFile(null, mmapInfo.getFirst(), mmapInfo.getSecond());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -930,6 +930,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
|
|||
WorkspaceConfiguration mmap = WorkspaceConfiguration.builder()
|
||||
.initialSize(1000000)
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.policyLearning(LearningPolicy.NONE)
|
||||
.build();
|
||||
|
||||
MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2");
|
||||
|
|
|
@ -344,6 +344,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
|||
.initialSize(200 * 1024L * 1024L) // 200mbs
|
||||
.tempFilePath(tmpFile.toAbsolutePath().toString())
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.policyLearning(LearningPolicy.NONE)
|
||||
.build();
|
||||
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
|
||||
|
@ -373,6 +374,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
|||
.initialSize(200 * 1024L * 1024L) // 200mbs
|
||||
.tempFilePath(tmpFile.toAbsolutePath().toString())
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.policyLearning(LearningPolicy.NONE)
|
||||
.build();
|
||||
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
|
||||
|
@ -380,6 +382,49 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeleteMappedFile_1() throws Exception {
|
||||
if (!Nd4j.getEnvironment().isCPU())
|
||||
return;
|
||||
|
||||
val tmpFile = Files.createTempFile("some", "file");
|
||||
val mmap = WorkspaceConfiguration.builder()
|
||||
.initialSize(200 * 1024L * 1024L) // 200mbs
|
||||
.tempFilePath(tmpFile.toAbsolutePath().toString())
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.policyLearning(LearningPolicy.NONE)
|
||||
.build();
|
||||
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
|
||||
val x = Nd4j.rand(DataType.FLOAT, 1024);
|
||||
}
|
||||
|
||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
||||
|
||||
Files.delete(tmpFile);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testDeleteMappedFile_2() throws Exception {
|
||||
if (!Nd4j.getEnvironment().isCPU())
|
||||
throw new IllegalArgumentException("Don't try to run on CUDA");
|
||||
|
||||
val tmpFile = Files.createTempFile("some", "file");
|
||||
val mmap = WorkspaceConfiguration.builder()
|
||||
.initialSize(200 * 1024L * 1024L) // 200mbs
|
||||
.tempFilePath(tmpFile.toAbsolutePath().toString())
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.build();
|
||||
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
|
||||
val x = Nd4j.rand(DataType.FLOAT, 1024);
|
||||
}
|
||||
|
||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
||||
|
||||
Files.delete(tmpFile);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue