fix overallocation for memory-mapped workspaces
Signed-off-by: raver119 <raver119@gmail.com>master
parent
550e84ef43
commit
78260efe54
|
@ -297,9 +297,8 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
|
|||
}
|
||||
|
||||
protected void init() {
|
||||
// we want params validation here
|
||||
|
||||
if (currentSize.get() > 0) {
|
||||
// we don't want overallocation in case of MMAP
|
||||
if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) {
|
||||
if (!isOver.get()) {
|
||||
if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
|
||||
&& workspaceConfiguration.getOverallocationLimit() > 0) {
|
||||
|
|
|
@ -92,11 +92,9 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable {
|
|||
if (currentSize.get() > 0) {
|
||||
isInit.set(true);
|
||||
|
||||
|
||||
//if (isDebug.get())
|
||||
if (isDebug.get())
|
||||
log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get());
|
||||
|
||||
|
||||
workspace.setHostPointer(new PagedPointer(memoryManager.allocate(currentSize.get() + SAFETY_OFFSET, MemoryKind.HOST, true)));
|
||||
AllocationsTracker.getInstance().markAllocated(AllocationKind.WORKSPACE, 0, currentSize.get() + SAFETY_OFFSET);
|
||||
}
|
||||
|
|
|
@ -26,16 +26,14 @@ import org.nd4j.linalg.BaseNd4jTest;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||
import org.nd4j.linalg.api.memory.enums.*;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
|
@ -335,6 +333,27 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(exp, res);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMmapedWorkspaceLimits_1() throws Exception {
|
||||
val tmpFile = Files.createTempFile("some", "file");
|
||||
val mmap = WorkspaceConfiguration.builder()
|
||||
.initialSize(200 * 1024L * 1024L) // 200mbs
|
||||
.tempFilePath(tmpFile.toAbsolutePath().toString())
|
||||
.policyLocation(LocationPolicy.MMAP)
|
||||
.build();
|
||||
|
||||
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) {
|
||||
int twoHundredMbsOfFloats = 52_428_800; // 200mbs % 4
|
||||
val addMoreFloats = true;
|
||||
if (addMoreFloats) {
|
||||
twoHundredMbsOfFloats += 1_000;
|
||||
}
|
||||
|
||||
val x = Nd4j.rand(DataType.FLOAT, twoHundredMbsOfFloats);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue