Few Workspace tweaks (#409)

* Allow to destroy workspaces on demand

Signed-off-by: raver119 <raver119@gmail.com>

* MMAP'd workspace can't have LearningPolicy

Signed-off-by: raver119 <raver119@gmail.com>

* throw an exception on CUDA

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-23 15:50:30 +03:00 committed by GitHub
parent f5f77df846
commit 5e779574cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 70 additions and 13 deletions

View File

@ -138,13 +138,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) {
OffsetType off = 0;
int prot = PROT_READ | PROT_WRITE;
// we need to convert long path (probably) to short pat (actually)
// it's Windows API, in the middle of 2018!
auto sz = GetShortPathName(fileName, nullptr, 0);
auto shortName = new TCHAR[sz];
GetShortPathName(fileName, shortName, sz);
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable: 4293)
@ -170,8 +163,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) {
h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
delete[] shortName;
if (h == INVALID_HANDLE_VALUE) {
errno = __map_mman_error(GetLastError(), EPERM);
nd4j_printf("Error code: %i\n", (int) errno);

View File

@ -297,8 +297,12 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
}
protected void init() {
// in case of MMAP we don't want any learning applied
if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP && workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE)
throw new IllegalArgumentException("Workspace backed by memory-mapped file can't have LearningPolicy defined");
// we don't want overallocation in case of MMAP
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
if (!isOver.get()) {
if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
&& workspaceConfiguration.getOverallocationLimit() > 0) {
@ -309,7 +313,6 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
if (workspaceConfiguration.getMaxSize() > 0 && currentSize.get() > workspaceConfiguration.getMaxSize())
currentSize.set(workspaceConfiguration.getMaxSize());
}
}

View File

@ -159,7 +159,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager {
if (workspace == null || workspace instanceof DummyWorkspace)
return;
//workspace.destroyWorkspace();
workspace.destroyWorkspace(true);
backingMap.get().remove(workspace.getId());
}

View File

@ -138,6 +138,13 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable {
}
}
protected long mappedFileSize() {
if (workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP)
return 0;
return tempFile.length();
}
@Override
protected void clearExternalAllocations() {
if (isDebug.get())

View File

@ -18,11 +18,15 @@ package org.nd4j.linalg.cpu.nativecpu.workspace;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.List;
import java.util.Queue;
@ -37,12 +41,16 @@ public class CpuWorkspaceDeallocator implements Deallocator {
private Queue<PointersPair> pinnedPointers;
private List<PointersPair> externalPointers;
private LocationPolicy location;
private Pair<LongPointer, Long> mmapInfo;
public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) {
this.pointersPair = workspace.workspace();
this.pinnedPointers = workspace.pinnedPointers();
this.externalPointers = workspace.externalPointers();
this.location = workspace.getWorkspaceConfiguration().getPolicyLocation();
if (workspace.mappedFileSize() > 0)
this.mmapInfo = Pair.makePair(workspace.mmap, workspace.mappedFileSize());
}
@Override
@ -50,7 +58,7 @@ public class CpuWorkspaceDeallocator implements Deallocator {
log.trace("Deallocating CPU workspace");
// purging workspace planes
if (pointersPair != null) {
if (pointersPair != null && (pointersPair.getDevicePointer() != null || pointersPair.getHostPointer() != null)) {
if (pointersPair.getDevicePointer() != null) {
Nd4j.getMemoryManager().release(pointersPair.getDevicePointer(), MemoryKind.DEVICE);
}
@ -58,6 +66,8 @@ public class CpuWorkspaceDeallocator implements Deallocator {
if (pointersPair.getHostPointer() != null) {
if (location != LocationPolicy.MMAP)
Nd4j.getMemoryManager().release(pointersPair.getHostPointer(), MemoryKind.HOST);
else
NativeOpsHolder.getInstance().getDeviceNativeOps().munmapFile(null, mmapInfo.getFirst(), mmapInfo.getSecond());
}
}

View File

@ -930,6 +930,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
WorkspaceConfiguration mmap = WorkspaceConfiguration.builder()
.initialSize(1000000)
.policyLocation(LocationPolicy.MMAP)
.policyLearning(LearningPolicy.NONE)
.build();
MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2");

View File

@ -344,6 +344,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
.initialSize(200 * 1024L * 1024L) // 200mbs
.tempFilePath(tmpFile.toAbsolutePath().toString())
.policyLocation(LocationPolicy.MMAP)
.policyLearning(LearningPolicy.NONE)
.build();
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
@ -373,6 +374,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
.initialSize(200 * 1024L * 1024L) // 200mbs
.tempFilePath(tmpFile.toAbsolutePath().toString())
.policyLocation(LocationPolicy.MMAP)
.policyLearning(LearningPolicy.NONE)
.build();
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
@ -380,6 +382,49 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
}
}
@Test
public void testDeleteMappedFile_1() throws Exception {
if (!Nd4j.getEnvironment().isCPU())
return;
val tmpFile = Files.createTempFile("some", "file");
val mmap = WorkspaceConfiguration.builder()
.initialSize(200 * 1024L * 1024L) // 200mbs
.tempFilePath(tmpFile.toAbsolutePath().toString())
.policyLocation(LocationPolicy.MMAP)
.policyLearning(LearningPolicy.NONE)
.build();
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
val x = Nd4j.rand(DataType.FLOAT, 1024);
}
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
Files.delete(tmpFile);
}
@Test(expected = IllegalArgumentException.class)
public void testDeleteMappedFile_2() throws Exception {
if (!Nd4j.getEnvironment().isCPU())
throw new IllegalArgumentException("Don't try to run on CUDA");
val tmpFile = Files.createTempFile("some", "file");
val mmap = WorkspaceConfiguration.builder()
.initialSize(200 * 1024L * 1024L) // 200mbs
.tempFilePath(tmpFile.toAbsolutePath().toString())
.policyLocation(LocationPolicy.MMAP)
.build();
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
val x = Nd4j.rand(DataType.FLOAT, 1024);
}
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
Files.delete(tmpFile);
}
@Override
public char ordering() {
return 'c';