cavis/libnd4j/include/memory/MemoryCounter.h

147 lines
4.8 KiB
C++

/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SD_MEMORYCOUNTER_H
#define SD_MEMORYCOUNTER_H
#include <pointercast.h>
#include <dll.h>
#include <map>
#include <memory/MemoryType.h>
#include <mutex>
namespace nd4j {
namespace memory {
/**
* This class provides simple per-device counter
*/
class ND4J_EXPORT MemoryCounter {
private:
static MemoryCounter* _INSTANCE;
// used for synchronization
std::mutex _locker;
// per-device counters
std::map<int, Nd4jLong> _deviceCounters;
// TODO: change this wrt heterogenous stuff on next iteration
// per-group counters
std::map<nd4j::memory::MemoryType, Nd4jLong> _groupCounters;
// per-device limits
std::map<int, Nd4jLong> _deviceLimits;
// per-group limits
std::map<nd4j::memory::MemoryType, Nd4jLong> _groupLimits;
MemoryCounter();
~MemoryCounter() = default;
public:
static MemoryCounter *getInstance();
/**
* This method checks if allocation of numBytes won't break through per-group or per-device limit
* @param numBytes
* @return TRUE if allocated ammount will keep us below limit, FALSE otherwise
*/
bool validate(Nd4jLong numBytes);
/**
* This method checks if allocation of numBytes won't break through per-device limit
* @param deviceId
* @param numBytes
* @return TRUE if allocated ammount will keep us below limit, FALSE otherwise
*/
bool validateDevice(int deviceId, Nd4jLong numBytes);
/**
* This method checks if allocation of numBytes won't break through per-group limit
* @param deviceId
* @param numBytes
* @return TRUE if allocated ammount will keep us below limit, FALSE otherwise
*/
bool validateGroup(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method adds specified number of bytes to specified counter
* @param deviceId
* @param numBytes
*/
void countIn(int deviceId, Nd4jLong numBytes);
void countIn(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method subtracts specified number of bytes from specified counter
* @param deviceId
* @param numBytes
*/
void countOut(int deviceId, Nd4jLong numBytes);
void countOut(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method returns amount of memory allocated on specified device
* @param deviceId
* @return
*/
Nd4jLong allocatedDevice(int deviceId);
/**
* This method returns amount of memory allocated in specified group of devices
* @param group
* @return
*/
Nd4jLong allocatedGroup(nd4j::memory::MemoryType group);
/**
* This method allows to set per-device memory limits
* @param deviceId
* @param numBytes
*/
void setDeviceLimit(int deviceId, Nd4jLong numBytes);
/**
* This method returns current device limit in bytes
* @param deviceId
* @return
*/
Nd4jLong deviceLimit(int deviceId);
/**
* This method allows to set per-group memory limits
* @param group
* @param numBytes
*/
void setGroupLimit(nd4j::memory::MemoryType group, Nd4jLong numBytes);
/**
* This method returns current group limit in bytes
* @param group
* @return
*/
Nd4jLong groupLimit(nd4j::memory::MemoryType group);
};
}
}
#endif //SD_MEMORYCOUNTER_H