[WIP] Memory limits (#167)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* one more initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* additional initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* subsequent initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit testing

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit per device

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit per group

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit for cuda

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit for cuda + few missed lines

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit for cuda + missed includes

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit for cuda + one more missed include

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit shouldn't count host mem as dev0 in cuda

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit that tracks HOST group limits for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit with some Environment changes

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit with more Environment changes

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit with maxMasterThreads fix

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit with maxMasterThreads fix

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit without maxMasterThreads exception

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit without Nd4jULong in Environment

Signed-off-by: raver119 <raver119@gmail.com>

* add sleep and more iterations for OOM cases

Signed-off-by: raver119 <raver119@gmail.com>

* limits propagation from java side

Signed-off-by: raver119 <raver119@gmail.com>

* - consume ErrorCode every time
- one test for memory limits

Signed-off-by: raver119 <raver119@gmail.com>

* unordered_map

Signed-off-by: raver119 <raver119@gmail.com>

* unordered_map

Signed-off-by: raver119 <raver119@gmail.com>

* unordered_map

Signed-off-by: raver119 <raver119@gmail.com>

* RSub op mapping fixed

Signed-off-by: raver119 <raver119@gmail.com>

* typo fixed

Signed-off-by: raver119 <raver119@gmail.com>

* one bad test fixed

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-24 10:11:09 +03:00 committed by GitHub
parent 0caf50f80f
commit 5d69069177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 844 additions and 46 deletions

View File

@ -26,6 +26,7 @@
#include <helpers/StringUtils.h>
#include <thread>
#include <helpers/logger.h>
#include <memory/MemoryCounter.h>
#ifdef _OPENMP
@ -291,11 +292,19 @@ namespace nd4j {
}
/**
* Placeholder setter for the maximum number of threads.
* The body currently contains no executable statements, so the argument is ignored.
* @param max requested maximum number of threads (unused)
*/
void Environment::setMaxThreads(int max) {
// FIXME: not possible at this moment, since maxThreads is limited by the number of threads in the pool. However, we can allocate more threads if we want
//_maxThreads.store(max);
}
/**
 * Updates the number of master threads.
 * The requested value is capped at maxThreads(); values below 1 are ignored
 * and the current setting is left untouched.
 * @param max requested number of master threads
 */
void Environment::setMaxMasterThreads(int max) {
    // never go above what maxThreads() reports
    const int requested = (max > maxThreads()) ? maxThreads() : max;

    // only positive values are accepted
    if (requested >= 1)
        _maxMasterThreads = requested;
}
bool Environment::precisionBoostAllowed() {
@ -334,6 +343,38 @@ namespace nd4j {
_allowHelpers.store(reallyAllow);
}
void Environment::setGroupLimit(int group, Nd4jLong numBytes) {
nd4j::memory::MemoryCounter::getInstance()->setGroupLimit((nd4j::memory::MemoryType) group, numBytes);
}
/**
 * Sets the limit, in bytes, for a single device by forwarding to MemoryCounter.
 * @param deviceId device identifier
 * @param numBytes limit passed through unchanged
 */
void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) {
    auto counter = nd4j::memory::MemoryCounter::getInstance();
    counter->setDeviceLimit(deviceId, numBytes);
}
/**
 * Returns the limit stored in MemoryCounter for a memory group.
 * @param group integer value converted to nd4j::memory::MemoryType
 * @return limit in bytes as reported by MemoryCounter::groupLimit
 */
Nd4jLong Environment::getGroupLimit(int group) {
    const auto memoryType = static_cast<nd4j::memory::MemoryType>(group);
    return nd4j::memory::MemoryCounter::getInstance()->groupLimit(memoryType);
}
/**
 * Returns the limit stored in MemoryCounter for a single device.
 * @param deviceId device identifier
 * @return limit in bytes as reported by MemoryCounter::deviceLimit
 */
Nd4jLong Environment::getDeviceLimit(int deviceId) {
    auto counter = nd4j::memory::MemoryCounter::getInstance();
    return counter->deviceLimit(deviceId);
}
/**
 * Returns the number of bytes currently counted for a memory group.
 * @param group integer value converted to nd4j::memory::MemoryType
 * @return value reported by MemoryCounter::allocatedGroup
 */
Nd4jLong Environment::getGroupCounter(int group) {
    const auto memoryType = static_cast<nd4j::memory::MemoryType>(group);
    return nd4j::memory::MemoryCounter::getInstance()->allocatedGroup(memoryType);
}
/**
 * Returns the number of bytes currently counted for a single device.
 * @param deviceId device identifier
 * @return value reported by MemoryCounter::allocatedDevice
 */
Nd4jLong Environment::getDeviceCounter(int deviceId) {
    auto counter = nd4j::memory::MemoryCounter::getInstance();
    return counter->allocatedDevice(deviceId);
}
/**
* Legacy memory limits API: returns the value currently held in _maxTotalPrimaryMemory.
* @return loaded value, converted to uint64_t
*/
uint64_t Environment::maxPrimaryMemory() {
return _maxTotalPrimaryMemory.load();
}
/**
* Legacy memory limits API: returns the value currently held in _maxTotalSpecialMemory.
* @return loaded value, converted to uint64_t
*/
uint64_t Environment::maxSpecialMemory() {
return _maxTotalSpecialMemory.load();
}
// storage for the singleton pointer; starts out null
nd4j::Environment *nd4j::Environment::_instance = nullptr;
}

View File

@ -27,6 +27,7 @@
#include <stdexcept>
#include <array/DataType.h>
#include <types/pair.h>
#include <pointercast.h>
namespace nd4j{
class ND4J_EXPORT Environment {
@ -97,10 +98,30 @@ namespace nd4j{
int maxMasterThreads();
void setMaxMasterThreads(int max);
/*
* Legacy memory limits API, still used in new API as simplified version
*/
void setMaxPrimaryMemory(uint64_t maxBytes);
void setMaxSpecialyMemory(uint64_t maxBytes);
void setMaxDeviceMemory(uint64_t maxBytes);
uint64_t maxPrimaryMemory();
uint64_t maxSpecialMemory();
////////////////////////
/*
* Methods for memory limits/counters
*/
void setGroupLimit(int group, Nd4jLong numBytes);
void setDeviceLimit(int deviceId, Nd4jLong numBytes);
Nd4jLong getGroupLimit(int group);
Nd4jLong getDeviceLimit(int deviceId);
Nd4jLong getGroupCounter(int group);
Nd4jLong getDeviceCounter(int deviceId);
////////////////////////
// returns the current value of the _useMKLDNN flag
bool isUseMKLDNN() { return _useMKLDNN.load(); }
// stores the given value into the _useMKLDNN flag
void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); }

View File

@ -44,6 +44,7 @@
#include <execution/AffinityManager.h>
#include <memory>
#include <array/InteropDataBuffer.h>
#include <memory/MemoryCounter.h>
namespace nd4j {

View File

@ -76,6 +76,7 @@ bool verbose = false;
#include <graph/execution/LogicExecutor.h>
#include <graph/ResultWrapper.h>
#include <DebugInfo.h>
#include <memory/MemoryCounter.h>
typedef nd4j::InteropDataBuffer OpaqueDataBuffer;

View File

@ -3093,8 +3093,14 @@ bool isOptimalRequirementsMet() {
}
/**
 * Creates a new InteropDataBuffer large enough for the requested number of elements.
 * Any std::exception raised on the way is not propagated: error code 1 and the
 * exception message are stored in the default LaunchContext error reference instead.
 * @param elements number of elements
 * @param dataType integer data type code, converted via DataTypeUtils::fromInt
 * @param allocateBoth forwarded to the InteropDataBuffer constructor
 * @return pointer to the new buffer, or nullptr if an exception was caught
 */
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
    try {
        const auto dtype = DataTypeUtils::fromInt(dataType);
        const auto lengthInBytes = elements * DataTypeUtils::sizeOf(dtype);
        return new nd4j::InteropDataBuffer(lengthInBytes, dtype, allocateBoth);
    } catch (const std::exception &err) {
        auto errorRef = nd4j::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(err.what());
        return nullptr;
    }
}
Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) {
@ -3126,7 +3132,12 @@ void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) {
}
/**
 * Expands the underlying data buffer so it can hold the requested number of elements.
 * The new size in bytes is elements multiplied by the size of the buffer's data type.
 * Any std::exception is reported through the default LaunchContext error reference
 * (error code 1 plus the exception message) rather than propagated.
 * @param dataBuffer buffer wrapper to expand
 * @param elements requested number of elements
 */
void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) {
    try {
        auto buffer = dataBuffer->dataBuffer();
        const auto lengthInBytes = elements * DataTypeUtils::sizeOf(buffer->getDataType());
        buffer->expand(lengthInBytes);
    } catch (const std::exception &err) {
        auto errorRef = nd4j::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(err.what());
    }
}
OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) {

View File

@ -3781,8 +3781,14 @@ void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
}
/**
 * Creates a new InteropDataBuffer large enough for the requested number of elements.
 * Any std::exception raised on the way is not propagated: error code 1 and the
 * exception message are stored in the default LaunchContext error reference instead.
 * @param elements number of elements
 * @param dataType integer data type code, converted via DataTypeUtils::fromInt
 * @param allocateBoth forwarded to the InteropDataBuffer constructor
 * @return pointer to the new buffer, or nullptr if an exception was caught
 */
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
    try {
        auto dtype = DataTypeUtils::fromInt(dataType);
        // single return statement: the duplicated (unreachable) copy has been dropped
        return new nd4j::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), dtype, allocateBoth);
    } catch (const std::exception &e) {
        nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
        return nullptr;
    }
}
Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) {
@ -3814,7 +3820,12 @@ void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) {
}
/**
 * Expands the underlying data buffer so it can hold the requested number of elements.
 * The new size in bytes is elements multiplied by the size of the buffer's data type.
 * Any std::exception is reported through the default LaunchContext error reference
 * (error code 1 plus the exception message) rather than propagated.
 * @param dataBuffer buffer wrapper to expand
 * @param elements requested number of elements
 */
void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) {
    try {
        auto buffer = dataBuffer->dataBuffer();
        const auto lengthInBytes = elements * DataTypeUtils::sizeOf(buffer->getDataType());
        buffer->expand(lengthInBytes);
    } catch (const std::exception &err) {
        auto errorRef = nd4j::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(err.what());
    }
}
OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) {

View File

@ -22,7 +22,7 @@
#define DEV_TESTS_CONSTANTDESCRIPTOR_H
#include <array/DataType.h>
#include <map>
#include <unordered_map>
#include <vector>
#include <pointercast.h>
#include <dll.h>

View File

@ -25,7 +25,7 @@
#include <string>
#include <atomic>
#include <map>
#include <unordered_map>
#include <NDArray.h>
#include <memory/Workspace.h>
#include <dll.h>

View File

@ -21,7 +21,7 @@
#ifndef DEV_TESTS_SHAPEDESCRIPTOR_H
#define DEV_TESTS_SHAPEDESCRIPTOR_H
#include <map>
#include <unordered_map>
#include <vector>
#include <dll.h>
#include <pointercast.h>

View File

@ -23,6 +23,9 @@
#include <DataTypeUtils.h>
#include <op_boilerplate.h>
#include <exceptions/cuda_exception.h>
#include <execution/AffinityManager.h>
#include <memory/MemoryCounter.h>
#include <exceptions/allocation_exception.h>
namespace nd4j {
void DataBuffer::expand(const uint64_t size) {
@ -64,8 +67,20 @@ namespace nd4j {
/**
 * Allocates the special buffer if it does not exist yet and the length is non-zero.
 * Outside of a workspace the request is first validated through MemoryCounter::validate
 * and, after allocation, counted towards the current device and the DEVICE group.
 * @throws allocation_exception if validation fails
 */
void DataBuffer::allocateSpecial() {
    // nothing to do: buffer already present, or zero length
    if (_specialBuffer != nullptr || !(getLenInBytes() > 0))
        return;

    const auto deviceId = nd4j::AffinityManager::currentDeviceId();

    // limits and counters are only applied when no workspace is attached
    const bool tracked = (_workspace == nullptr);

    if (tracked) {
        auto counter = nd4j::memory::MemoryCounter::getInstance();
        if (!counter->validate(getLenInBytes())) {
            throw nd4j::allocation_exception::build("Requested amount exceeds device limits", counter->deviceLimit(deviceId), getLenInBytes());
        }
    }

    ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t);
    _isOwnerSpecial = true;

    if (tracked) {
        auto counter = nd4j::memory::MemoryCounter::getInstance();
        counter->countIn(deviceId, getLenInBytes());
        counter->countIn(nd4j::memory::MemoryType::DEVICE, getLenInBytes());
    }
}

View File

@ -23,6 +23,8 @@
#include <helpers/logger.h>
#include <array/DataTypeUtils.h>
#include <execution/AffinityManager.h>
#include <memory/MemoryCounter.h>
#include <exceptions/allocation_exception.h>
namespace nd4j {
///// IMPLEMENTATION OF COMMON METHODS /////
@ -232,14 +234,35 @@ namespace nd4j {
/**
 * Allocates the primary buffer if it does not exist yet and the length is non-zero.
 * Outside of a workspace the request is validated first: against the device limit
 * when Environment reports a CPU backend, against the HOST group limit otherwise.
 * After allocation the bytes are counted towards the HOST group and, on a CPU
 * backend only, towards the current device.
 * @throws allocation_exception if validation fails
 */
void DataBuffer::allocatePrimary() {
    // nothing to do: buffer already present, or zero length
    if (_primaryBuffer != nullptr || !(getLenInBytes() > 0))
        return;

    const auto deviceId = nd4j::AffinityManager::currentDeviceId();

    // limits and counters are only applied when no workspace is attached
    const bool tracked = (_workspace == nullptr);

    // check that this allocation will not bring us above the limit
    if (tracked) {
        auto counter = nd4j::memory::MemoryCounter::getInstance();
        if (Environment::getInstance()->isCPU()) {
            // on the cpu backend we validate against the current device for now
            if (!counter->validate(getLenInBytes())) {
                throw nd4j::allocation_exception::build("Requested amount exceeds HOST device limits", counter->deviceLimit(deviceId), getLenInBytes());
            }
        } else {
            // in heterogeneous mode we validate against the HOST group
            if (!counter->validateGroup(nd4j::memory::MemoryType::HOST, getLenInBytes())) {
                throw nd4j::allocation_exception::build("Requested amount exceeds HOST group limits", counter->groupLimit(nd4j::memory::MemoryType::HOST), getLenInBytes());
            }
        }
    }

    ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
    _isOwnerPrimary = true;

    // count the bytes in, unless we are in workspace mode
    if (tracked) {
        auto counter = nd4j::memory::MemoryCounter::getInstance();

        // host memory must not be added to a CUDA device counter
        if (Environment::getInstance()->isCPU())
            counter->countIn(deviceId, getLenInBytes());

        counter->countIn(nd4j::memory::MemoryType::HOST, getLenInBytes());
    }
}
////////////////////////////////////////////////////////////////////////
// Overwrites both ownership flags with the given values.
// isOwnerPrimary is stored in _isOwnerPrimary, isOwnerSpecial in _isOwnerSpecial.
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
@ -252,6 +275,15 @@ namespace nd4j {
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
// count out towards DataBuffer device, only if we're not in workspace
if (_workspace == nullptr) {
if (Environment::getInstance()->isCPU())
nd4j::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes());
nd4j::memory::MemoryCounter::getInstance()->countOut(nd4j::memory::MemoryType::HOST, getLenInBytes());
}
}
}

View File

@ -40,6 +40,7 @@ namespace nd4j {
~allocation_exception() = default;
static allocation_exception build(std::string message, Nd4jLong bytes);
static allocation_exception build(std::string message, Nd4jLong limit, Nd4jLong bytes);
};
}

View File

@ -31,4 +31,11 @@ namespace nd4j {
message += "; Requested bytes: [" + bytes + "]";
return allocation_exception(message);
}
/**
 * Builds an allocation_exception whose message carries both the limit and the requested size.
 * The suffix "; Limit bytes: [<limit>]; Requested bytes: [<numBytes>]" is appended to message.
 * @param message base message text
 * @param limit limit value, in bytes, to report
 * @param numBytes requested amount, in bytes, to report
 * @return exception constructed from the extended message
 */
allocation_exception allocation_exception::build(std::string message, Nd4jLong limit, Nd4jLong numBytes) {
    const auto requestedText = StringUtils::valueToString<Nd4jLong>(numBytes);
    const auto limitText = StringUtils::valueToString<Nd4jLong>(limit);

    message.append("; Limit bytes: [");
    message.append(limitText);
    message.append("]; Requested bytes: [");
    message.append(requestedText);
    message.append("]");

    return allocation_exception(message);
}
}

View File

@ -23,7 +23,7 @@
#include <vector>
#include <initializer_list>
#include <map>
#include <unordered_map>
#include <string>
#include <flatbuffers/flatbuffers.h>
#include <graph/Variable.h>

View File

@ -23,7 +23,7 @@
#include <list>
#include <algorithm>
#include <map>
#include <unordered_map>
//#include <NDArray.h>
#include <graph/Node.h>
#include <graph/Stash.h>

View File

@ -20,7 +20,7 @@
#include <helpers/logger.h>
#include <pointercast.h>
#include <map>
#include <unordered_map>
#include <graph/Graph.h>
#include <helpers/SimpleReadWriteLock.h>
#include <exceptions/unknown_graph_exception.h>

View File

@ -25,7 +25,7 @@
#include <op_boilerplate.h>
#include <dll.h>
#include <vector>
#include <map>
#include <unordered_map>
#include <graph/Scope.h>
#include <Status.h>
#include <graph/VariableSpace.h>

View File

@ -22,7 +22,7 @@
#define LIBND4J_SCOPE_H
#include <string>
#include <map>
#include <unordered_map>
#include <graph/Node.h>
namespace nd4j {

View File

@ -23,7 +23,7 @@
//#include <graph/Block.h>
#include <NDArray.h>
#include <map>
#include <unordered_map>
#include <string>
#include <atomic>
#include <pointercast.h>

View File

@ -26,7 +26,7 @@
#include <string>
#include <vector>
#include <list>
#include <map>
#include <unordered_map>
#include <mutex>
#include <NDArray.h>
#include <array/NDArrayList.h>

View File

@ -0,0 +1,146 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SD_MEMORYCOUNTER_H
#define SD_MEMORYCOUNTER_H
#include <pointercast.h>
#include <dll.h>
#include <map>
#include <memory/MemoryType.h>
#include <mutex>
namespace nd4j {
namespace memory {
/**
* This class provides simple per-device and per-group memory counters together with
* optional limits. Counters and limits are kept in maps guarded by a single mutex.
* Instances are obtained through getInstance(); constructor and destructor are private.
*/
class ND4J_EXPORT MemoryCounter {
private:
// pointer returned by getInstance()
static MemoryCounter* _INSTANCE;
// used for synchronization
std::mutex _locker;
// per-device counters
std::map<int, Nd4jLong> _deviceCounters;
// TODO: change this wrt heterogeneous stuff on next iteration
// per-group counters
std::map<nd4j::memory::MemoryType, Nd4jLong> _groupCounters;
// per-device limits
std::map<int, Nd4jLong> _deviceLimits;
// per-group limits
std::map<nd4j::memory::MemoryType, Nd4jLong> _groupLimits;
MemoryCounter();
~MemoryCounter() = default;
public:
// returns the shared MemoryCounter instance
static MemoryCounter *getInstance();
/**
* This method checks if allocation of numBytes won't break through the per-device limit
* of the device reported by AffinityManager::currentDeviceId()
* @param numBytes
* @return TRUE if allocated amount will keep us within the limit, FALSE otherwise
*/
bool validate(Nd4jLong numBytes);
/**
* This method checks if allocation of numBytes won't break through per-device limit.
* A limit that is zero or negative is treated as "no limit"
* @param deviceId
* @param numBytes
* @return TRUE if allocated amount will keep us within the limit, FALSE otherwise
*/
bool validateDevice(int deviceId, Nd4jLong numBytes);
/**
* This method checks if allocation of numBytes won't break through per-group limit.
* A limit that is zero or negative is treated as "no limit"
* @param group
* @param numBytes
* @return TRUE if allocated amount will keep us within the limit, FALSE otherwise
*/
bool validateGroup(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method adds specified number of bytes to specified counter
* (per-device counter for the deviceId overload, per-group counter for the group overload)
* @param deviceId
* @param numBytes
*/
void countIn(int deviceId, Nd4jLong numBytes);
void countIn(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method subtracts specified number of bytes from specified counter
* (per-device counter for the deviceId overload, per-group counter for the group overload)
* @param deviceId
* @param numBytes
*/
void countOut(int deviceId, Nd4jLong numBytes);
void countOut(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method returns amount of memory counted for specified device
* @param deviceId
* @return current value of the per-device counter, in bytes
*/
Nd4jLong allocatedDevice(int deviceId);
/**
* This method returns amount of memory counted for specified group of devices
* @param group
* @return current value of the per-group counter, in bytes
*/
Nd4jLong allocatedGroup(nd4j::memory::MemoryType group);
/**
* This method allows to set per-device memory limits
* @param deviceId
* @param numBytes
*/
void setDeviceLimit(int deviceId, Nd4jLong numBytes);
/**
* This method returns current device limit in bytes
* @param deviceId
* @return stored per-device limit
*/
Nd4jLong deviceLimit(int deviceId);
/**
* This method allows to set per-group memory limits
* @param group
* @param numBytes
*/
void setGroupLimit(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method returns current group limit in bytes
* @param group
* @return stored per-group limit
*/
Nd4jLong groupLimit(nd4j::memory::MemoryType group);
};
}
}
#endif //SD_MEMORYCOUNTER_H

View File

@ -30,6 +30,9 @@
namespace nd4j {
namespace memory {
/**
* This class is used for tracking memory allocation wrt their allocation points in code
*/
class ND4J_EXPORT MemoryTracker {
private:
static MemoryTracker* _INSTANCE;

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../MemoryCounter.h"
#include <execution/AffinityManager.h>
#include <Environment.h>
#include <helpers/logger.h>
namespace nd4j {
namespace memory {
/**
 * Initializes all counters to zero, per-device limits to zero and
 * per-group limits from the Environment legacy memory settings.
 */
MemoryCounter::MemoryCounter() {
    // every device reported by AffinityManager starts with zero usage and a zero limit
    const auto deviceCount = nd4j::AffinityManager::numberOfDevices();
    for (int deviceId = 0; deviceId < deviceCount; ++deviceId) {
        _deviceCounters[deviceId] = 0;
        _deviceLimits[deviceId] = 0;
    }

    // group counters start from zero as well
    _groupCounters[nd4j::memory::MemoryType::HOST] = 0;
    _groupCounters[nd4j::memory::MemoryType::DEVICE] = 0;

    // group limits are seeded from the Environment values
    auto environment = nd4j::Environment::getInstance();
    _groupLimits[nd4j::memory::MemoryType::HOST] = environment->maxPrimaryMemory();
    _groupLimits[nd4j::memory::MemoryType::DEVICE] = environment->maxSpecialMemory();
}
/**
 * Returns the shared MemoryCounter, creating it on first use.
 * Creation goes through std::call_once so that concurrent first calls cannot
 * race on _INSTANCE or construct more than one object.
 * @return pointer to the shared instance (never released)
 */
MemoryCounter* MemoryCounter::getInstance() {
    static std::once_flag created;
    std::call_once(created, []() { _INSTANCE = new MemoryCounter(); });

    return _INSTANCE;
}
void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) {
std::lock_guard<std::mutex> lock(_locker);
_deviceCounters[deviceId] += numBytes;
}
// Adds numBytes to the counter of the given memory group, under the internal mutex.
void MemoryCounter::countIn(nd4j::memory::MemoryType group, Nd4jLong numBytes) {
    std::lock_guard<std::mutex> guard(_locker);

    auto &counter = _groupCounters[group];
    counter += numBytes;
}
void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) {
std::lock_guard<std::mutex> lock(_locker);
_deviceCounters[deviceId] -= numBytes;
}
// Subtracts numBytes from the counter of the given memory group, under the internal mutex.
void MemoryCounter::countOut(nd4j::memory::MemoryType group, Nd4jLong numBytes) {
    std::lock_guard<std::mutex> guard(_locker);

    auto &counter = _groupCounters[group];
    counter -= numBytes;
}
// Validates numBytes against the per-device limit of the device reported by
// AffinityManager::currentDeviceId(); see validateDevice for the rules applied.
bool MemoryCounter::validate(Nd4jLong numBytes) {
    return validateDevice(nd4j::AffinityManager::currentDeviceId(), numBytes);
}
/**
 * Checks whether numBytes more can be counted for the device without exceeding its limit.
 * A limit of zero or less means the device is unlimited.
 * @param deviceId device identifier
 * @param numBytes size of the intended allocation, in bytes
 * @return true if the device is unlimited or the request fits into the remaining headroom
 */
bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) {
    std::lock_guard<std::mutex> lock(_locker);

    const auto dLimit = _deviceLimits[deviceId];
    if (dLimit <= 0)
        return true;

    const auto dAlloc = _deviceCounters[deviceId];

    // compare against the remaining headroom instead of computing numBytes + dAlloc:
    // the sum overflows the signed type (undefined behavior) for very large requests
    return numBytes <= dLimit - dAlloc;
}
/**
 * Checks whether numBytes more can be counted for the group without exceeding its limit.
 * A limit of zero or less means the group is unlimited.
 * @param group memory group
 * @param numBytes size of the intended allocation, in bytes
 * @return true if the group is unlimited or the request fits into the remaining headroom
 */
bool MemoryCounter::validateGroup(nd4j::memory::MemoryType group, Nd4jLong numBytes) {
    std::lock_guard<std::mutex> lock(_locker);

    const auto gLimit = _groupLimits[group];
    if (gLimit <= 0)
        return true;

    const auto gAlloc = _groupCounters[group];

    // compare against the remaining headroom instead of computing numBytes + gAlloc:
    // the sum overflows the signed type (undefined behavior) for very large requests
    return numBytes <= gLimit - gAlloc;
}
// Returns the current counter value of the given device, read under the internal mutex.
Nd4jLong MemoryCounter::allocatedDevice(int deviceId) {
    std::lock_guard<std::mutex> guard(_locker);

    const Nd4jLong allocated = _deviceCounters[deviceId];
    return allocated;
}
// Returns the current counter value of the given memory group, read under the internal mutex.
Nd4jLong MemoryCounter::allocatedGroup(nd4j::memory::MemoryType group) {
    std::lock_guard<std::mutex> guard(_locker);

    const Nd4jLong allocated = _groupCounters[group];
    return allocated;
}
void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) {
std::lock_guard<std::mutex> lock(_locker);
_deviceLimits[deviceId] = numBytes;
}
// Stores numBytes as the limit of the given memory group, under the internal mutex.
void MemoryCounter::setGroupLimit(nd4j::memory::MemoryType group, Nd4jLong numBytes) {
    std::lock_guard<std::mutex> guard(_locker);

    auto &limit = _groupLimits[group];
    limit = numBytes;
}
// Returns the stored limit of the given device, read under the internal mutex.
Nd4jLong MemoryCounter::deviceLimit(int deviceId) {
    std::lock_guard<std::mutex> guard(_locker);

    const Nd4jLong limit = _deviceLimits[deviceId];
    return limit;
}
// Returns the stored limit of the given memory group, read under the internal mutex.
Nd4jLong MemoryCounter::groupLimit(nd4j::memory::MemoryType group) {
    std::lock_guard<std::mutex> guard(_locker);

    const Nd4jLong limit = _groupLimits[group];
    return limit;
}
// storage for the singleton pointer; starts out null
MemoryCounter* MemoryCounter::_INSTANCE = nullptr;
}
}

View File

@ -23,7 +23,7 @@
#include <pointercast.h>
#include <vector>
#include <map>
#include <unordered_map>
#include <mutex>
#include <ops/declarable/DeclarableOp.h>
#include <ops/declarable/PlatformHelper.h>

View File

@ -22,7 +22,7 @@
#include <ops/declarable/helpers/segment.h>
#include <ShapeUtils.h>
#include <execution/Threads.h>
#include <map>
#include <unordered_map>
namespace nd4j {
namespace ops {

View File

@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <NDArray.h>
#include <Context.h>
#include <Node.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <helpers/RandomLauncher.h>
using namespace nd4j;
using namespace nd4j::graph;
using namespace nd4j::memory;
// Empty test fixture used by the DataBuffer tests below; adds no state of its own.
class DataBufferTests : public testing::Test {
public:
};
/**
 * Verifies that a DataBuffer allocation is reflected in the per-device and HOST group
 * counters, and that an allocation above the configured limits raises allocation_exception.
 * Original limits are restored before the final expectation is checked.
 */
TEST_F(DataBufferTests, test_alloc_limit_1) {
    // this scenario only applies to the CPU backend
    if (!Environment::getInstance()->isCPU())
        return;

    auto deviceId = AffinityManager::currentDeviceId();

    // remember original limits and current usage
    auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId);
    auto ogLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST);
    auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId);
    auto ogUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST);

    auto limitSize = 150 * 1024 * 1024;
    auto allocSize = 100000000;

    // each limit is raised relative to ITS OWN original value
    MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, ogLimit + limitSize);

    DataBuffer buffer(allocSize, DataType::INT32);

    // separately testing per-device limits and group limits
    ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId));
    ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST));

    // setting smaller limits, to make sure next allocation fails with OOM exception
    MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, allocSize - 100);

    // the outcome is recorded in a flag, so that the limits below are restored
    // even when the expected exception does not show up
    bool oomRaised = false;
    try {
        DataBuffer bufferFailed(allocSize, DataType::INT32);
    } catch (const allocation_exception &e) {
        // we expect exception here
        oomRaised = true;
    }

    // restore original limits, so subsequent tests do not fail
    MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, ogLimit);

    ASSERT_TRUE(oomRaised);
}

View File

@ -0,0 +1,87 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <NDArray.h>
#include <Context.h>
#include <Node.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <helpers/RandomLauncher.h>
using namespace nd4j;
using namespace nd4j::graph;
using namespace nd4j::memory;
// Empty test fixture used by the CUDA DataBuffer tests below; adds no state of its own.
class DataBufferTestsCuda : public testing::Test {
public:
};
/**
 * Verifies that allocating a DataBuffer with both buffers is reflected in the per-device,
 * HOST group and DEVICE group counters, and that an allocation above the configured
 * limits raises allocation_exception.
 * Original limits are restored before the final expectation is checked.
 */
TEST_F(DataBufferTestsCuda, test_alloc_limit_1) {
    auto deviceId = AffinityManager::currentDeviceId();

    // remember original limits and current usage
    auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId);
    auto opLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST);
    auto osLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::DEVICE);
    auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId);
    auto opUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST);
    auto osUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE);

    auto limitSize = 150000000;
    auto allocSize = 100000000;

    MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit + limitSize);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, osLimit + limitSize);

    DataBuffer buffer(allocSize, DataType::INT32, nullptr, true);

    // separately testing per-device limits and group limits
    ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId));
    ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST));
    ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE));

    // setting smaller limits, to make sure next allocation fails with OOM exception
    MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, allocSize - 100);

    // this allocation should fail, since we're allocating too much;
    // the outcome is recorded in a flag, so that the limits below are restored
    // even when the expected exception does not show up
    bool oomRaised = false;
    try {
        DataBuffer bufferFailed(allocSize, DataType::INT32);
    } catch (const allocation_exception &e) {
        // we expect exception here
        oomRaised = true;
    }

    // restore original limits, so subsequent tests do not fail
    MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit);
    MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, osLimit);

    ASSERT_TRUE(oomRaised);
}

View File

@ -29,8 +29,6 @@
#include <ops/declarable/helpers/col2im.h>
#include <helpers/RandomLauncher.h>
using namespace nd4j;
using namespace nd4j;
using namespace nd4j::graph;

View File

@ -63,7 +63,7 @@ public class RSubOp extends BaseDynamicTransformOp {
/**
 * This op reports no TensorFlow op name.
 * The stale {@code return "Sub";} that preceded the throw (and made it unreachable) is removed.
 *
 * @return never returns normally
 * @throws NoOpNameFoundException always
 */
@Override
public String tensorflowName() {
    throw new NoOpNameFoundException("No TensorFlow op name found for: " + getClass().getName());
}
public RSubOp( INDArray[] inputs, INDArray[] outputs) {

View File

@ -18,18 +18,23 @@
package org.nd4j.nativeblas;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
import java.util.concurrent.locks.LockSupport;
/**
* This class is an opaque pointer to InteropDataBuffer, used for Java/C++ interop related to INDArray DataBuffer
*
* @author saudet
* @author raver119@gmail.com
*/
@Slf4j
public class OpaqueDataBuffer extends Pointer {
// TODO: make this configurable
// NOTE(review): presumably the number of attempts made by the retrying native calls in this class - confirm against usages.
// Declared once only; the previous value of 3 is superseded by 5.
private static final int MAX_TRIES = 5;
public OpaqueDataBuffer(Pointer p) { super(p); }
@ -53,11 +58,13 @@ public class OpaqueDataBuffer extends Pointer {
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if allocation failed it might be caused by casual OOM, so we'll try GC
System.gc();
// sleeping for 50ms
Thread.sleep(50);
} else {
// just return the buffer
return buffer;
@ -89,11 +96,12 @@ public class OpaqueDataBuffer extends Pointer {
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if expansion failed it might be caused by casual OOM, so we'll try GC
System.gc();
Thread.sleep(50);
} else {
// just return
return;
@ -126,11 +134,13 @@ public class OpaqueDataBuffer extends Pointer {
// check error code
ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
if (ec != 0) {
if (em == null)
em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
// if view creation failed it might be caused by casual OOM, so we'll try GC
System.gc();
// sleeping to let gc kick in
Thread.sleep(50);
} else {
// just return
return buffer;

View File

@ -167,4 +167,29 @@ public class CudaEnvironment implements Environment {
public boolean isCPU() {
return e.isCPU();
}
/**
* Sets the limit for a memory group by forwarding the call to the wrapped environment object {@code e}.
*
* @param group group identifier, passed through unchanged
* @param numBytes limit value, passed through unchanged
*/
@Override
public void setGroupLimit(int group, long numBytes) {
e.setGroupLimit(group, numBytes);
}
/**
* Sets the limit for a device by forwarding the call to the wrapped environment object {@code e}.
*
* @param deviceId device identifier, passed through unchanged
* @param numBytes limit value, passed through unchanged
*/
@Override
public void setDeviceLimit(int deviceId, long numBytes) {
e.setDeviceLimit(deviceId, numBytes);
}
/**
* Returns the limit of a memory group as reported by the wrapped environment object {@code e}.
*
* @param group group identifier, passed through unchanged
* @return value returned by {@code e.getGroupLimit(group)}
*/
@Override
public long getGroupLimit(int group) {
return e.getGroupLimit(group);
}
/**
* Returns the limit of a device as reported by the wrapped environment object {@code e}.
*
* @param deviceId device identifier, passed through unchanged
* @return value returned by {@code e.getDeviceLimit(deviceId)}
*/
@Override
public long getDeviceLimit(int deviceId) {
return e.getDeviceLimit(deviceId);
}
/**
* Returns the counter of a device as reported by {@code e.getDeviceCounter(deviceId)}.
* NOTE(review): the method name is spelled "Couner"; since it is an {@code @Override},
* any rename has to start at the overridden declaration.
*
* @param deviceId device identifier, passed through unchanged
* @return value returned by {@code e.getDeviceCounter(deviceId)}
*/
@Override
public long getDeviceCouner(int deviceId) {
return e.getDeviceCounter(deviceId);
}
}

View File

@ -495,7 +495,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// #define DEV_TESTS_CONSTANTDESCRIPTOR_H
// #include <array/DataType.h>
// #include <map>
// #include <unordered_map>
// #include <vector>
// #include <pointercast.h>
// #include <dll.h>
@ -808,6 +808,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// #include <stdexcept>
// #include <array/DataType.h>
// #include <types/pair.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class Environment extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -846,10 +847,30 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// Getter/setter pair for the native "max master threads" setting.
public native int maxMasterThreads();
public native void setMaxMasterThreads(int max);
/*
* Legacy memory limits API, still used in new API as simplified version.
*
* Byte amounts are cast to uint64_t on the native side (see the @Cast annotations),
* while Java passes them as signed long.
*
* NOTE(review): "setMaxSpecialyMemory" is misspelled relative to the "maxSpecialMemory"
* getter. As a native binding its name presumably mirrors the C++ method, so it cannot
* be corrected in this file alone - confirm against the native header.
*/
public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes);
public native @Cast("uint64_t") long maxPrimaryMemory();
public native @Cast("uint64_t") long maxSpecialMemory();
////////////////////////
/*
* Methods for memory limits/counters
*
* Limits and counters are addressed either by group id or by device id, and byte
* amounts are cast to Nd4jLong on the native side (see the @Cast annotations).
* NOTE(review): the meaning of the "group" values is not visible here - confirm
* against the native header.
*/
public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes);
public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes);
public native @Cast("Nd4jLong") long getGroupLimit(int group);
public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId);
public native @Cast("Nd4jLong") long getGroupCounter(int group);
public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId);
////////////////////////
// Getter/setter pair for the native "use MKLDNN" flag.
public native @Cast("bool") boolean isUseMKLDNN();
public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN);
@ -1017,6 +1038,7 @@ bool verbose = false;
// #include <graph/execution/LogicExecutor.h>
// #include <graph/ResultWrapper.h>
// #include <DebugInfo.h>
// #include <memory/MemoryCounter.h>
/**
* This function returns last error code stored,
@ -3591,6 +3613,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <execution/AffinityManager.h>
// #include <memory>
// #include <array/InteropDataBuffer.h>
// #include <memory/MemoryCounter.h>
@ -4856,7 +4879,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <string>
// #include <atomic>
// #include <map>
// #include <unordered_map>
// #include <NDArray.h>
// #include <memory/Workspace.h>
// #include <dll.h>
@ -5007,6 +5030,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <chrono>
// #include <array/DataTypeUtils.h>
// #include <helpers/logger.h>
// #include <stdexcept>
// #ifdef __CUDACC__
// #endif
@ -5458,7 +5482,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
//#include <graph/Block.h>
// #include <NDArray.h>
// #include <map>
// #include <unordered_map>
// #include <string>
// #include <atomic>
// #include <pointercast.h>
@ -5549,7 +5573,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <op_boilerplate.h>
// #include <dll.h>
// #include <vector>
// #include <map>
// #include <unordered_map>
// #include <graph/Scope.h>
// #include <Status.h>
// #include <graph/VariableSpace.h>
@ -5665,7 +5689,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <string>
// #include <vector>
// #include <list>
// #include <map>
// #include <unordered_map>
// #include <mutex>
// #include <NDArray.h>
// #include <array/NDArrayList.h>
@ -9674,7 +9698,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <pointercast.h>
// #include <vector>
// #include <map>
// #include <unordered_map>
// #include <mutex>
// #include <ops/declarable/DeclarableOp.h>
// #include <ops/declarable/PlatformHelper.h>
@ -9939,7 +9963,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H
// #define DEV_TESTS_SHAPEDESCRIPTOR_H
// #include <map>
// #include <unordered_map>
// #include <vector>
// #include <dll.h>
// #include <pointercast.h>

View File

@ -167,4 +167,29 @@ public class CpuEnvironment implements Environment {
public boolean isCPU() {
return e.isCPU();
}
/**
 * Forwards the group memory limit to the wrapped environment instance {@code e}.
 *
 * @param group    group identifier, passed through unchanged
 * @param numBytes limit for the group, in bytes
 */
@Override
public void setGroupLimit(final int group, final long numBytes) {
    this.e.setGroupLimit(group, numBytes);
}
/**
 * Forwards the per-device memory limit to the wrapped environment instance {@code e}.
 *
 * @param deviceId device identifier, passed through unchanged
 * @param numBytes limit for the device, in bytes
 */
@Override
public void setDeviceLimit(final int deviceId, final long numBytes) {
    this.e.setDeviceLimit(deviceId, numBytes);
}
/**
 * Reads the group memory limit from the wrapped environment instance {@code e}.
 *
 * @param group group identifier, passed through unchanged
 * @return the value reported by {@code e.getGroupLimit(group)}
 */
@Override
public long getGroupLimit(final int group) {
    final long groupLimit = this.e.getGroupLimit(group);
    return groupLimit;
}
/**
 * Reads the per-device memory limit from the wrapped environment instance {@code e}.
 *
 * @param deviceId device identifier, passed through unchanged
 * @return the value reported by {@code e.getDeviceLimit(deviceId)}
 */
@Override
public long getDeviceLimit(final int deviceId) {
    final long deviceLimit = this.e.getDeviceLimit(deviceId);
    return deviceLimit;
}
/**
 * Reads the per-device counter from the wrapped environment instance {@code e}.
 * <p>
 * The method name is spelled "Couner" because that is how the {@code Environment}
 * interface declares it; the wrapped call uses the correctly spelled
 * {@code getDeviceCounter}.
 *
 * @param deviceId device identifier, passed through unchanged
 * @return the value reported by {@code e.getDeviceCounter(deviceId)}
 */
@Override
public long getDeviceCouner(final int deviceId) {
    final long deviceCounter = this.e.getDeviceCounter(deviceId);
    return deviceCounter;
}
}

View File

@ -573,7 +573,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// #define DEV_TESTS_CONSTANTDESCRIPTOR_H
// #include <array/DataType.h>
// #include <map>
// #include <unordered_map>
// #include <vector>
// #include <pointercast.h>
// #include <dll.h>
@ -811,6 +811,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// #include <stdexcept>
// #include <array/DataType.h>
// #include <types/pair.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class Environment extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -849,10 +850,30 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// Getter/setter pair for the native "max master threads" setting.
public native int maxMasterThreads();
public native void setMaxMasterThreads(int max);
/*
* Legacy memory limits API, still used in new API as simplified version.
*
* Byte amounts are cast to uint64_t on the native side (see the @Cast annotations),
* while Java passes them as signed long.
*
* NOTE(review): "setMaxSpecialyMemory" is misspelled relative to the "maxSpecialMemory"
* getter. As a native binding its name presumably mirrors the C++ method, so it cannot
* be corrected in this file alone - confirm against the native header.
*/
public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes);
public native @Cast("uint64_t") long maxPrimaryMemory();
public native @Cast("uint64_t") long maxSpecialMemory();
////////////////////////
/*
* Methods for memory limits/counters
*
* Limits and counters are addressed either by group id or by device id, and byte
* amounts are cast to Nd4jLong on the native side (see the @Cast annotations).
* NOTE(review): the meaning of the "group" values is not visible here - confirm
* against the native header.
*/
public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes);
public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes);
public native @Cast("Nd4jLong") long getGroupLimit(int group);
public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId);
public native @Cast("Nd4jLong") long getGroupCounter(int group);
public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId);
////////////////////////
// Getter/setter pair for the native "use MKLDNN" flag.
public native @Cast("bool") boolean isUseMKLDNN();
public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN);
@ -1020,6 +1041,7 @@ bool verbose = false;
// #include <graph/execution/LogicExecutor.h>
// #include <graph/ResultWrapper.h>
// #include <DebugInfo.h>
// #include <memory/MemoryCounter.h>
/**
* This function returns last error code stored,
@ -3594,6 +3616,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <execution/AffinityManager.h>
// #include <memory>
// #include <array/InteropDataBuffer.h>
// #include <memory/MemoryCounter.h>
@ -4859,7 +4882,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <string>
// #include <atomic>
// #include <map>
// #include <unordered_map>
// #include <NDArray.h>
// #include <memory/Workspace.h>
// #include <dll.h>
@ -5010,6 +5033,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <chrono>
// #include <array/DataTypeUtils.h>
// #include <helpers/logger.h>
// #include <stdexcept>
// #ifdef __CUDACC__
// #endif
@ -5461,7 +5485,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
//#include <graph/Block.h>
// #include <NDArray.h>
// #include <map>
// #include <unordered_map>
// #include <string>
// #include <atomic>
// #include <pointercast.h>
@ -5552,7 +5576,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <op_boilerplate.h>
// #include <dll.h>
// #include <vector>
// #include <map>
// #include <unordered_map>
// #include <graph/Scope.h>
// #include <Status.h>
// #include <graph/VariableSpace.h>
@ -5668,7 +5692,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <string>
// #include <vector>
// #include <list>
// #include <map>
// #include <unordered_map>
// #include <mutex>
// #include <NDArray.h>
// #include <array/NDArrayList.h>
@ -11885,7 +11909,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <pointercast.h>
// #include <vector>
// #include <map>
// #include <unordered_map>
// #include <mutex>
// #include <ops/declarable/DeclarableOp.h>
// #include <ops/declarable/PlatformHelper.h>
@ -17106,6 +17130,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Input : batched tensor with rank >=2
* Output: tensor with rank lesser by 1 from input
*/
// #if NOT_EXCLUDED(OP_matrix_diag_part)
@Namespace("nd4j::ops") public static class matrix_diag_part extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -17121,7 +17146,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* QR decomposition: A = QR, where Q is orthogonal (Q * QT = I) and R is upper triangular.
* For A (MxN) Q is M x M and R is (NxN).
*
* Input :
* 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matrices
*
* Output:
* 0 - float tensor with shape {.,..,...,MxN} - batch of orthogonal matrices {Qs}
* 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matrices {Rs}
*/
// #if NOT_EXCLUDED(OP_qr)
// Java-side binding class for the native op type nd4j::ops::qr (namespace given by @Namespace).
@Namespace("nd4j::ops") public static class qr extends DeclarableCustomOp {
// Loader.load() is invoked once, when this class is initialized.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public qr(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public qr(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
// Covariant override: same call as the supertype's position(long), but typed as qr.
@Override public qr position(long position) {
return (qr)super.position(position);
}
// Default constructor: starts from a null pointer, then calls the native allocate().
public qr() { super((Pointer)null); allocate(); }
private native void allocate();
// Native method: takes the input shapes and the op context (by reference), returns a ShapeList.
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for survivors.
@ -23687,7 +23741,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H
// #define DEV_TESTS_SHAPEDESCRIPTOR_H
// #include <map>
// #include <unordered_map>
// #include <vector>
// #include <dll.h>
// #include <pointercast.h>

View File

@ -459,7 +459,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test
public void testSubiRowVector() {
INDArray oneThroughFour = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape('c', 2, 2);
INDArray row1 = oneThroughFour.getRow(1);
INDArray row1 = oneThroughFour.getRow(1).dup();
oneThroughFour.subiRowVector(row1);
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
assertEquals(getFailureMessage(), result, oneThroughFour);

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.junit.Ignore;
@ -32,7 +33,7 @@ import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
/**
* Created by Alex on 30/04/2016.
@ -109,6 +110,44 @@ public class TestNDArrayCreation extends BaseNd4jTest {
assertEquals(arrCreate.data().address(), pointer.address());
}
/**
 * Sets the environment's max special memory value to 1 and then allocates and fills
 * large FLOAT arrays in a loop that has no exit condition. It never finishes on its
 * own, which is why it is kept ignored.
 */
@Test
@Ignore // this is endless test
public void testEndlessAllocation() {
    Nd4j.getEnvironment().setMaxSpecialMemory(1);

    for (;;) {
        // 100M-element uninitialized buffer, then a write over every element
        val chunk = Nd4j.createUninitialized(DataType.FLOAT, 100000000);
        chunk.assign(1.0f);
    }
}
/**
 * Verifies that a per-device memory limit is enforced: with roughly 10000 bytes of
 * headroom over the current device counter, one 1024-element DOUBLE allocation must
 * succeed and a second one must throw. Afterwards no native error code may be left
 * pending.
 * <p>
 * The original device limit is restored in a {@code finally} block, so a failing
 * assertion cannot leave the restrictive limit active for tests that run later.
 *
 * @throws Exception declared for parity with the other tests in this class
 */
@Test
@Ignore("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
public void testAllocationLimits() throws Exception {
    // NOTE(review): presumably forces backend initialization so the counter baseline
    // read below already includes start-up allocations - confirm
    Nd4j.create(1);

    val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
    val origDeviceCount = Nd4j.getEnvironment().getDeviceCouner(0);

    // headroom fits one 1024-element DOUBLE buffer (8192 bytes at 8 bytes per element) but not two
    val limit = origDeviceCount + 10000;

    Nd4j.getEnvironment().setDeviceLimit(0, limit);
    try {
        val array = Nd4j.createUninitialized(DataType.DOUBLE, 1024);
        assertNotNull(array);

        try {
            Nd4j.createUninitialized(DataType.DOUBLE, 1024);
            // fail() throws AssertionError, which the catch (Exception) below does not swallow
            fail("Allocation above the device limit was expected to throw");
        } catch (Exception e) {
            // expected: the second allocation exceeds the device limit
        }

        // we want to be sure there's nothing left after exception
        assertEquals(0, NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode());
    } finally {
        // always put the original limit back, even when an assertion above failed
        Nd4j.getEnvironment().setDeviceLimit(0, origDeviceLimit);
    }
}
@Override
public char ordering() {

View File

@ -82,4 +82,39 @@ public interface Environment {
/** Return true if the backend is a CPU backend, or false otherwise */
boolean isCPU();
/**
* Sets the memory limit for a specific group of devices, e.g. CUDA or CPU.
*
* @param group identifier of the device group the limit applies to
*              (NOTE(review): the valid group values are not defined in this interface - confirm)
* @param numBytes memory limit for the group, in bytes
*/
void setGroupLimit(int group, long numBytes);
/**
* Sets the memory limit for a specific device, e.g. GPU_0.
*
* @param deviceId identifier of the device the limit applies to
* @param numBytes memory limit for the device, in bytes
*/
void setDeviceLimit(int deviceId, long numBytes);
/**
* Returns the current memory limit of a device group.
*
* @param group identifier of the device group to query
* @return the current limit for the given group
*/
long getGroupLimit(int group);
/**
* Returns the current memory limit of a device.
*
* @param deviceId identifier of the device to query
* @return the current limit for the given device
*/
long getDeviceLimit(int deviceId);
/**
* Returns the currently allocated amount for a specific device, e.g. GPU_0.
*
* NOTE(review): the method name is misspelled ("Couner" instead of "Counter").
* Renaming it would break existing implementations and callers, so any fix needs
* a coordinated change.
*
* @param deviceId identifier of the device to query
* @return the amount currently allocated on the given device
*/
long getDeviceCouner(int deviceId);
}