From 78260efe54b331a74e68db96cf0f46609a257dd6 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 22 Apr 2020 12:38:22 +0300 Subject: [PATCH] fix overallocation for memory-mapped workspaces Signed-off-by: raver119 --- .../api/memory/abstracts/Nd4jWorkspace.java | 5 ++-- .../cpu/nativecpu/workspace/CpuWorkspace.java | 4 +-- .../workspace/SpecialWorkspaceTests.java | 27 ++++++++++++++++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java index 823b67bbe..9fa23bf45 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java @@ -297,9 +297,8 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { } protected void init() { - // we want params validation here - - if (currentSize.get() > 0) { + // we don't want overallocation in case of MMAP + if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) { if (!isOver.get()) { if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE && workspaceConfiguration.getOverallocationLimit() > 0) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java index 0f83c0445..ae21a82cb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java @@ -92,11 +92,9 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable { if (currentSize.get() > 0) { isInit.set(true); - - //if (isDebug.get()) + if (isDebug.get()) log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get()); - workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true))); AllocationsTracker.getInstance().markAllocated(AllocationKind.WORKSPACE, 0, currentSize.get() + SAFETY_OFFSET); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 8d7ebb040..391efa475 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -26,16 +26,14 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.AllocationPolicy; -import org.nd4j.linalg.api.memory.enums.LearningPolicy; -import org.nd4j.linalg.api.memory.enums.ResetPolicy; -import org.nd4j.linalg.api.memory.enums.SpillPolicy; +import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; +import java.nio.file.Files; import java.util.ArrayList; import java.util.Arrays; @@ -335,6 +333,27 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { assertEquals(exp, res); } + + @Test + public void testMmapedWorkspaceLimits_1() throws Exception { + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + int twoHundredMbsOfFloats = 52_428_800; // 200mbs % 4 + val addMoreFloats = true; + if (addMoreFloats) { + twoHundredMbsOfFloats += 1_000; + } + + val x = Nd4j.rand(DataType.FLOAT, twoHundredMbsOfFloats); + } + } + @Override public char ordering() { return 'c';