cavis/libnd4j/include/array/DataBuffer.h

374 lines
12 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef DEV_TESTS_DATABUFFER_H
#define DEV_TESTS_DATABUFFER_H
#include <cstring>
#include <op_boilerplate.h>
#include <dll.h>
#include <pointercast.h>
#include <array/DataType.h>
#include <memory/Workspace.h>
#include <execution/LaunchContext.h>
namespace nd4j {
class ND4J_EXPORT DataBuffer {
private:
void* _primaryBuffer;
void* _specialBuffer;
size_t _lenInBytes;
DataType _dataType;
memory::Workspace* _workspace;
bool _isOwnerPrimary;
bool _isOwnerSpecial;
#ifdef __CUDABLAS__
mutable std::atomic<Nd4jLong> _counter;
mutable std::atomic<Nd4jLong> _writePrimary;
mutable std::atomic<Nd4jLong> _writeSpecial;
mutable std::atomic<Nd4jLong> _readPrimary;
mutable std::atomic<Nd4jLong> _readSpecial;
#endif
void setCountersToZero();
void copyCounters(const DataBuffer& other);
void deleteSpecial();
FORCEINLINE void deletePrimary();
FORCEINLINE void deleteBuffers();
FORCEINLINE void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false);
void allocateBuffers(const bool allocBoth = false);
void setSpecial(void* special, const bool isOwnerSpecial);
void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0);
public:
FORCEINLINE DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false, const bool isOwnerSpecial = false,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(void* primary,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary = false,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer
const DataType dataType, const size_t lenInBytes,
memory::Workspace* workspace = nullptr);
FORCEINLINE DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false);
FORCEINLINE DataBuffer(const DataBuffer& other);
FORCEINLINE DataBuffer(DataBuffer&& other);
FORCEINLINE explicit DataBuffer();
FORCEINLINE ~DataBuffer();
FORCEINLINE DataBuffer& operator=(const DataBuffer& other);
FORCEINLINE DataBuffer& operator=(DataBuffer&& other) noexcept;
FORCEINLINE DataType getDataType();
FORCEINLINE size_t getLenInBytes() const;
FORCEINLINE void* primary();
FORCEINLINE void* special();
FORCEINLINE void allocatePrimary();
void allocateSpecial();
void writePrimary() const;
void writeSpecial() const;
void readPrimary() const;
void readSpecial() const;
bool isPrimaryActual() const;
bool isSpecialActual() const;
void migrate();
template <typename T> FORCEINLINE T* primaryAsT();
template <typename T> FORCEINLINE T* specialAsT();
void syncToPrimary(const LaunchContext* context, const bool forceSync = false);
void syncToSpecial(const bool forceSync = false);
void setToZeroBuffers(const bool both = false);
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
};
///// IMLEMENTATION OF INLINE METHODS /////
////////////////////////////////////////////////////////////////////////
// default constructor
DataBuffer::DataBuffer() {
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = 0;
_dataType = INT8;
_workspace = nullptr;
_isOwnerPrimary = false;
_isOwnerSpecial = false;
setCountersToZero();
}
////////////////////////////////////////////////////////////////////////
// copy constructor
DataBuffer::DataBuffer(const DataBuffer &other) {
throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!");
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
allocateBuffers();
copyBufferFrom(other);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, void* special,
const size_t lenInBytes, const DataType dataType,
const bool isOwnerPrimary, const bool isOwnerSpecial,
memory::Workspace* workspace) {
if (primary == nullptr && special == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !");
_primaryBuffer = primary;
_specialBuffer = special;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
setCountersToZero();
if(primary != nullptr)
readPrimary();
if(special != nullptr)
readSpecial();
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace):
DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) {
syncToSpecial(true);
}
////////////////////////////////////////////////////////////////////////
// copies data from hostBuffer to own memory buffer
DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) {
if (hostBuffer == nullptr)
throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !");
if (lenInBytes == 0)
throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !");
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
_lenInBytes = lenInBytes;
_dataType = dataType;
_workspace = workspace;
setCountersToZero();
allocateBuffers();
copyBufferFromHost(hostBuffer, lenInBytes);
}
////////////////////////////////////////////////////////////////////////
DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) {
_dataType = dataType;
_workspace = workspace;
_lenInBytes = lenInBytes;
_primaryBuffer = nullptr;
_specialBuffer = nullptr;
setCountersToZero();
if(lenInBytes != 0) {
allocateBuffers(allocBoth);
writeSpecial();
}
}
////////////////////////////////////////////////////////////////////////
// move constructor
DataBuffer::DataBuffer(DataBuffer&& other) {
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
// assignment operator
DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
if (this == &other)
return *this;
deleteBuffers();
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
if (this == &other)
return *this;
deleteBuffers();
_primaryBuffer = other._primaryBuffer;
_specialBuffer = other._specialBuffer;
_lenInBytes = other._lenInBytes;
_dataType = other._dataType;
_workspace = other._workspace;
_isOwnerPrimary = other._isOwnerPrimary;
_isOwnerSpecial = other._isOwnerSpecial;
copyCounters(other);
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::primary() {
return _primaryBuffer;
}
////////////////////////////////////////////////////////////////////////
void* DataBuffer::special() {
return _specialBuffer;
}
////////////////////////////////////////////////////////////////////////
DataType DataBuffer::getDataType() {
return _dataType;
}
////////////////////////////////////////////////////////////////////////
size_t DataBuffer::getLenInBytes() const {
return _lenInBytes;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T* DataBuffer::primaryAsT() {
return reinterpret_cast<T*>(_primaryBuffer);
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T* DataBuffer::specialAsT() {
return reinterpret_cast<T*>(_specialBuffer);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocatePrimary() {
if (_primaryBuffer == nullptr && getLenInBytes() > 0) {
ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t);
_isOwnerPrimary = true;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) {
_isOwnerPrimary = isOwnerPrimary;
_isOwnerSpecial = isOwnerSpecial;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deletePrimary() {
if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) {
auto p = reinterpret_cast<int8_t*>(_primaryBuffer);
RELEASE(p, _workspace);
_primaryBuffer = nullptr;
_isOwnerPrimary = false;
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::deleteBuffers() {
deletePrimary();
deleteSpecial();
_lenInBytes = 0;
}
////////////////////////////////////////////////////////////////////////
DataBuffer::~DataBuffer() {
deleteBuffers();
}
}
#endif //DEV_TESTS_DATABUFFER_H