fix overallocation for memory-mapped workspaces

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-22 12:38:22 +03:00
parent 550e84ef43
commit 78260efe54
3 changed files with 26 additions and 10 deletions

View File

@ -297,9 +297,8 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
} }
protected void init() { protected void init() {
// we want params validation here // we don't want overallocation in case of MMAP
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
if (currentSize.get() > 0) {
if (!isOver.get()) { if (!isOver.get()) {
if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
&& workspaceConfiguration.getOverallocationLimit() > 0) { && workspaceConfiguration.getOverallocationLimit() > 0) {

View File

@ -92,11 +92,9 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable {
if (currentSize.get() > 0) { if (currentSize.get() > 0) {
isInit.set(true); isInit.set(true);
if (isDebug.get())
//if (isDebug.get())
log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get()); log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get());
workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true))); workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true)));
AllocationsTracker.getInstance().markAllocated(AllocationKind.WORKSPACE, 0, currentSize.get() + SAFETY_OFFSET); AllocationsTracker.getInstance().markAllocated(AllocationKind.WORKSPACE, 0, currentSize.get() + SAFETY_OFFSET);
} }

View File

@ -26,16 +26,14 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import java.nio.file.Files;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -335,6 +333,27 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(exp, res); assertEquals(exp, res);
} }
@Test
public void testMmapedWorkspaceLimits_1() throws Exception {
val tmpFile = Files.createTempFile("some", "file");
val mmap = WorkspaceConfiguration.builder()
.initialSize(200 * 1024L * 1024L) // 200mbs
.tempFilePath(tmpFile.toAbsolutePath().toString())
.policyLocation(LocationPolicy.MMAP)
.build();
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
int twoHundredMbsOfFloats = 52_428_800; // 200mbs % 4
val addMoreFloats = true;
if (addMoreFloats) {
twoHundredMbsOfFloats += 1_000;
}
val x = Nd4j.rand(DataType.FLOAT, twoHundredMbsOfFloats);
}
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';