C++ rearrangements (#485)
* initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * some minor singleton changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * more iterations Signed-off-by: raver119 <raver119@gmail.com> * more singletons updated Signed-off-by: raver119 <raver119@gmail.com> * more singletons updated Signed-off-by: raver119 <raver119@gmail.com> * more changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * CUDA updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Java side update Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one commented out test Signed-off-by: raver119@gmail.com <raver119@gmail.com>master
parent
ee3e059b12
commit
ac7fb903d7
|
@ -77,7 +77,7 @@ If you're adding new ops, and want to make sure they run ok on your specific dev
|
|||
Despite being simple - it still provides you with time spent in various parts of Graph.
|
||||
|
||||
```c++
|
||||
Environment::getInstance()->setProfiling(true);
|
||||
Environment::getInstance().setProfiling(true);
|
||||
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb");
|
||||
|
||||
auto profile = GraphProfilingHelper::profile(graph, 1000);
|
||||
|
|
|
@ -22,37 +22,40 @@
|
|||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <memory>
|
||||
#include <array/PointerWrapper.h>
|
||||
#include <array/DataType.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT ConstantDataBuffer {
|
||||
private:
|
||||
Nd4jPointer _primaryBuffer = nullptr;
|
||||
Nd4jPointer _specialBuffer = nullptr;
|
||||
Nd4jLong _length = 0;
|
||||
Nd4jLong _sizeOf = 0;
|
||||
std::shared_ptr<PointerWrapper> _primaryBuffer;
|
||||
std::shared_ptr<PointerWrapper> _specialBuffer = nullptr;
|
||||
uint64_t _length = 0;
|
||||
uint8_t _sizeOf = 0;
|
||||
|
||||
public:
|
||||
ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, Nd4jLong numEelements, Nd4jLong sizeOf);
|
||||
ConstantDataBuffer(const std::shared_ptr<PointerWrapper>& primary, uint64_t numEelements, DataType dype);
|
||||
ConstantDataBuffer(const std::shared_ptr<PointerWrapper>& primary, const std::shared_ptr<PointerWrapper>& special, uint64_t numEelements, DataType dype);
|
||||
ConstantDataBuffer(const ConstantDataBuffer &other);
|
||||
ConstantDataBuffer() = default;
|
||||
~ConstantDataBuffer() = default;
|
||||
|
||||
Nd4jLong sizeOf() const;
|
||||
Nd4jLong length() const;
|
||||
uint8_t sizeOf() const;
|
||||
uint64_t length() const;
|
||||
|
||||
Nd4jPointer primary() const;
|
||||
Nd4jPointer special() const;
|
||||
void* primary() const;
|
||||
void* special() const;
|
||||
|
||||
ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
|
||||
ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;
|
||||
|
||||
template <typename T>
|
||||
T* primaryAsT() const;
|
||||
|
||||
template <typename T>
|
||||
T* primaryAsT();
|
||||
|
||||
template <typename T>
|
||||
T* specialAsT();
|
||||
T* specialAsT() const;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_ARRAY_CONSTANTOFFSETSBUFFER_H_
|
||||
#define SD_ARRAY_CONSTANTOFFSETSBUFFER_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <memory>
|
||||
#include <array/PointerWrapper.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
class ND4J_EXPORT ConstantOffsetsBuffer {
|
||||
private:
|
||||
std::shared_ptr<PointerWrapper> _primaryOffsets;
|
||||
std::shared_ptr<PointerWrapper> _specialOffsets;
|
||||
|
||||
public:
|
||||
ConstantOffsetsBuffer(const std::shared_ptr<PointerWrapper> &primary);
|
||||
ConstantOffsetsBuffer(const std::shared_ptr<PointerWrapper> &primary, const std::shared_ptr<PointerWrapper> &special);
|
||||
ConstantOffsetsBuffer() = default;
|
||||
~ConstantOffsetsBuffer() = default;
|
||||
|
||||
const Nd4jLong* primary() const;
|
||||
const Nd4jLong* special() const;
|
||||
const Nd4jLong* platform() const;
|
||||
};
|
||||
|
||||
} // namespace sd
|
||||
|
||||
#endif //SD_ARRAY_CONSTANTOFFSETSBUFFER_H_
|
|
@ -0,0 +1,49 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_ARRAY_CONSTANTSHAPEBUFFER_H_
|
||||
#define SD_ARRAY_CONSTANTSHAPEBUFFER_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <array/PointerWrapper.h>
|
||||
#include <memory>
|
||||
|
||||
namespace sd {
|
||||
|
||||
class ND4J_EXPORT ConstantShapeBuffer {
|
||||
private:
|
||||
std::shared_ptr<PointerWrapper> _primaryShapeInfo;
|
||||
std::shared_ptr<PointerWrapper> _specialShapeInfo;
|
||||
|
||||
public:
|
||||
ConstantShapeBuffer(const std::shared_ptr<PointerWrapper> &primary);
|
||||
ConstantShapeBuffer(const std::shared_ptr<PointerWrapper> &primary, const std::shared_ptr<PointerWrapper> &special);
|
||||
ConstantShapeBuffer() = default;
|
||||
~ConstantShapeBuffer() = default;
|
||||
|
||||
const Nd4jLong* primary() const;
|
||||
const Nd4jLong* special() const;
|
||||
const Nd4jLong* platform() const;
|
||||
};
|
||||
|
||||
} // namespace sd
|
||||
|
||||
#endif //SD_ARRAY_CONSTANTSHAPEBUFFER_H_
|
|
@ -0,0 +1,38 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_CUDAYPOINTERDEALLOCATOR_H_
|
||||
#define SD_CUDAYPOINTERDEALLOCATOR_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <array/PointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT CudaPointerDeallocator : public PointerDeallocator {
|
||||
public:
|
||||
CudaPointerDeallocator() = default;
|
||||
~CudaPointerDeallocator() = default;
|
||||
|
||||
void release(void* ptr) override;
|
||||
};
|
||||
}
|
||||
|
||||
#endif //SD_CUDAYPOINTERDEALLOCATOR_H_
|
|
@ -110,7 +110,7 @@ namespace sd {
|
|||
// if proposed dataType is already floating point - return it
|
||||
if (isR(typeX))
|
||||
return typeX;
|
||||
return Environment::getInstance()->defaultFloatDataType();
|
||||
return Environment::getInstance().defaultFloatDataType();
|
||||
}
|
||||
|
||||
FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) {
|
||||
|
@ -154,7 +154,7 @@ namespace sd {
|
|||
// if both data types are float - return biggest one
|
||||
if (rX && rY) {
|
||||
// if we allow precision boost, then we pick bigger data type
|
||||
if (sd::Environment::getInstance()->precisionBoostAllowed()) {
|
||||
if (sd::Environment::getInstance().precisionBoostAllowed()) {
|
||||
return nd4j_max(typeX, typeY);
|
||||
} else {
|
||||
// and we return first operand otherwise
|
||||
|
@ -165,7 +165,7 @@ namespace sd {
|
|||
|
||||
// if that's not real type, we apply same rules
|
||||
if (!rX && !rY) {
|
||||
if (sd::Environment::getInstance()->precisionBoostAllowed()) {
|
||||
if (sd::Environment::getInstance().precisionBoostAllowed()) {
|
||||
return nd4j_max(typeX, typeY);
|
||||
} else {
|
||||
// and we return first operand otherwise
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
#include <memory>
|
||||
#include <array/InteropDataBuffer.h>
|
||||
#include <memory/MemoryCounter.h>
|
||||
#include <array/ConstantShapeBuffer.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
|
@ -155,8 +156,8 @@ namespace sd {
|
|||
/**
|
||||
* contains shape info: matrix rank, numbers of elements per each dimension, dimensions strides, element-wise-stride, c-like or fortan-like order
|
||||
*/
|
||||
Nd4jLong *_shapeInfo = nullptr;
|
||||
Nd4jLong *_shapeInfoD = nullptr;
|
||||
const Nd4jLong *_shapeInfo = nullptr;
|
||||
const Nd4jLong *_shapeInfoD = nullptr;
|
||||
|
||||
/**
|
||||
* pointer on device launch context (with all data needed there).
|
||||
|
@ -1219,7 +1220,7 @@ namespace sd {
|
|||
void setShapeInfo(const Nd4jLong *shapeInfo);
|
||||
void setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype);
|
||||
void setShapeInfo(const ShapeDescriptor& descriptor);
|
||||
void setShapeInfo(const ConstantDataBuffer& shapeBuffer);
|
||||
void setShapeInfo(const ConstantShapeBuffer& shapeBuffer);
|
||||
|
||||
/**
|
||||
* returns absolute offset which corresponds to given sequential index
|
||||
|
@ -1516,9 +1517,9 @@ FORCEINLINE R NDArray::templatedGet(void const* buffer, Nd4jLong index) const {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::setShapeInfo(Nd4jLong *shapeInfo) {
|
||||
auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo);
|
||||
_shapeInfo = buffer.primaryAsT<Nd4jLong>();
|
||||
_shapeInfoD = buffer.specialAsT<Nd4jLong>();
|
||||
auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo);
|
||||
_shapeInfo = buffer.primary();
|
||||
_shapeInfoD = buffer.special();
|
||||
|
||||
if (shapeInfo != nullptr) {
|
||||
_dataType = ArrayOptions::dataType(_shapeInfo);
|
||||
|
@ -1535,9 +1536,9 @@ void NDArray::setShapeInfo(Nd4jLong *shapeInfo) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype) {
|
||||
auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo);
|
||||
_shapeInfo = buffer.primaryAsT<Nd4jLong>();
|
||||
_shapeInfoD = buffer.specialAsT<Nd4jLong>();
|
||||
auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(shapeInfo);
|
||||
_shapeInfo = buffer.primary();
|
||||
_shapeInfoD = buffer.special();
|
||||
|
||||
if (shapeInfo != nullptr) {
|
||||
_dataType = dtype;
|
||||
|
@ -1623,7 +1624,7 @@ bool NDArray::nonNull() const {
|
|||
if (isEmpty())
|
||||
return true;
|
||||
|
||||
if(!Environment::getInstance()->isCPU())
|
||||
if(!Environment::getInstance().isCPU())
|
||||
return getDataBuffer()->special() != nullptr && specialShapeInfo() != nullptr;
|
||||
|
||||
return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr;
|
||||
|
|
|
@ -181,7 +181,7 @@ NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, const bool isSc
|
|||
_buffer->setToZeroBuffers();
|
||||
}
|
||||
else
|
||||
setShapeInfo(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
|
||||
setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1088,9 +1088,11 @@ void NDArray::streamline(char o) {
|
|||
char order = o == 'a' ? this->ordering() : o;
|
||||
syncToDevice();
|
||||
std::shared_ptr<DataBuffer> newBuffer = std::make_shared<DataBuffer>(this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(dataType(), order, rankOf(), shapeOf());
|
||||
NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(), static_cast<Nd4jLong*>(shapeBuffer.primary()), newBuffer->special(), static_cast<Nd4jLong*>(shapeBuffer.special()), nullptr, nullptr, nullptr);
|
||||
setShapeInfo(static_cast<Nd4jLong*>(shapeBuffer.primary()));
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(dataType(), order, rankOf(), shapeOf());
|
||||
NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(),
|
||||
shapeBuffer.primary(), newBuffer->special(),
|
||||
shapeBuffer.special(), nullptr, nullptr, nullptr);
|
||||
setShapeInfo(shapeBuffer);
|
||||
_buffer = newBuffer;
|
||||
_offset = 0;
|
||||
tickWriteDevice();
|
||||
|
@ -1355,7 +1357,7 @@ NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector
|
|||
|
||||
std::vector<int> copy(dimensions);
|
||||
|
||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance()->defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace());
|
||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace());
|
||||
|
||||
NDArray result(newShape, true, getContext());
|
||||
|
||||
|
@ -1432,7 +1434,7 @@ NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, void *extraParams) const
|
|||
if (isS())
|
||||
throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!");
|
||||
|
||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
|
||||
auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
|
||||
NDArray result(shape, true, this->getContext());
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this});
|
||||
|
@ -1461,7 +1463,7 @@ NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void *extraParams) const {
|
|||
if (isS())
|
||||
throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!");
|
||||
|
||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
|
||||
auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL);
|
||||
NDArray result(shape, true, this->getContext());
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this});
|
||||
|
@ -1476,7 +1478,7 @@ NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void *extraParams) const {
|
|||
if (isS())
|
||||
throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!");
|
||||
|
||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
|
||||
auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64);
|
||||
NDArray result(shape, true, this->getContext());
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this});
|
||||
|
@ -1854,8 +1856,7 @@ void NDArray::setAttached(bool reallyAttached) {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
// calculate strides
|
||||
void NDArray::updateStrides(const char order) {
|
||||
shape::updateStrides(_shapeInfo, order);
|
||||
syncShape();
|
||||
throw std::runtime_error("Forbidden method");
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2456,7 +2457,7 @@ void NDArray::operator+=(const NDArray& other) {
|
|||
|
||||
if (isS())
|
||||
throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
|
||||
|
||||
if (this->lengthOf() != 1 && other.lengthOf() == 1) {
|
||||
|
@ -2490,7 +2491,7 @@ void NDArray::operator-=(const NDArray& other) {
|
|||
if (isS())
|
||||
throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!");
|
||||
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
|
||||
|
||||
if (lengthOf() != 1 && other.lengthOf() == 1) {
|
||||
|
@ -2523,7 +2524,7 @@ void NDArray::operator-=(const NDArray& other) {
|
|||
void NDArray::operator*=(const NDArray& other) {
|
||||
if (isS())
|
||||
throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
|
||||
|
||||
if (lengthOf() != 1 && other.lengthOf() == 1) {
|
||||
|
@ -2559,7 +2560,7 @@ void NDArray::operator/=(const NDArray& other) {
|
|||
if (other.isB())
|
||||
throw std::runtime_error("NDArray::operator/=: you can't divide by bool array!");
|
||||
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType()) {
|
||||
if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType()) {
|
||||
throw sd::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType());
|
||||
}
|
||||
|
||||
|
@ -2832,14 +2833,14 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other,
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
xShapeInfoH = xPack.primary();
|
||||
xShapeInfoD = xPack.special();
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
yShapeInfoH = yPack.primary();
|
||||
yShapeInfoD = yPack.special();
|
||||
}
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &other});
|
||||
|
@ -2883,14 +2884,14 @@ void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& ot
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
xShapeInfoH = xPack.primary();
|
||||
xShapeInfoD = xPack.special();
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
yShapeInfoH = yPack.primary();
|
||||
yShapeInfoD = yPack.special();
|
||||
}
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &other});
|
||||
|
@ -2934,12 +2935,12 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& oth
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
}
|
||||
|
@ -3067,7 +3068,7 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dime
|
|||
|
||||
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||
// NDArray::prepareSpecialUse({&target}, {this, &other});
|
||||
// NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
|
||||
// NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
|
||||
// NDArray::registerSpecialUse({&target}, {this, &other});
|
||||
// return;
|
||||
// }
|
||||
|
@ -3088,12 +3089,12 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dime
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
}
|
||||
|
@ -3119,7 +3120,7 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
|
|||
|
||||
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||
// NDArray::prepareSpecialUse({&target}, {this, &other});
|
||||
// NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
|
||||
// NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
|
||||
// NDArray::registerSpecialUse({&target}, {this, &other});
|
||||
// return;
|
||||
// }
|
||||
|
@ -3142,12 +3143,12 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
}
|
||||
|
@ -3174,7 +3175,7 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& d
|
|||
|
||||
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||
// NDArray::prepareSpecialUse({&target}, {this, &other});
|
||||
// NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
|
||||
// NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
|
||||
// NDArray::registerSpecialUse({&target}, {this, &other});
|
||||
// return;
|
||||
// }
|
||||
|
@ -3197,12 +3198,12 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& d
|
|||
Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
|
||||
|
||||
if(!isSameShape(target)) {
|
||||
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
|
||||
xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
|
||||
xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
|
||||
}
|
||||
if(!other.isSameShape(target)) {
|
||||
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
|
||||
yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
|
||||
yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
|
||||
}
|
||||
|
@ -3220,8 +3221,8 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::initializer_list<
|
|||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void* NDArray::operator new(size_t i) {
|
||||
if (sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) {
|
||||
sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance()->getWorkspace();
|
||||
if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) {
|
||||
sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance().getWorkspace();
|
||||
return ws->allocateBytes((Nd4jLong) i);
|
||||
}
|
||||
else {
|
||||
|
@ -3233,7 +3234,7 @@ void* NDArray::operator new(size_t i) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::operator delete(void* p) {
|
||||
if (!sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached())
|
||||
if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached())
|
||||
free(p);
|
||||
}
|
||||
|
||||
|
@ -3439,8 +3440,8 @@ void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, cons
|
|||
NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), biasCorrected);
|
||||
else {
|
||||
std::vector<int> copy(dimensions);
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimensions);
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimensions);
|
||||
NativeOpExecutioner::execSummaryStats(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected);
|
||||
synchronize("NDArray::varianceAlongDimension");
|
||||
}
|
||||
|
@ -4109,8 +4110,8 @@ void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const s
|
|||
else {
|
||||
std::vector<int> copy = dimensions;
|
||||
shape::checkDimensions(rankOf(), copy);
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy);
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
|
||||
NativeOpExecutioner::execIndexReduce(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
synchronize("NDArray::applyIndexReduce");
|
||||
}
|
||||
|
@ -4183,10 +4184,10 @@ NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const s
|
|||
}
|
||||
else {
|
||||
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy);
|
||||
auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(other.shapeInfo(), copy);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
|
||||
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy);
|
||||
|
||||
if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1))
|
||||
throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !");
|
||||
|
@ -4212,15 +4213,15 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, cons
|
|||
shape::checkDimensions(rankOf(), copy);
|
||||
shape::checkDimensions(other.rankOf(), copy);
|
||||
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy);
|
||||
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other.shapeInfo(), copy);
|
||||
auto packX = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
|
||||
auto packY = ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy);
|
||||
|
||||
// check tads shapes
|
||||
if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()))
|
||||
throw std::runtime_error("NDArray::applyAllReduce3 method: the shapes of array tads are different !");
|
||||
|
||||
// set newShape for output array
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()});
|
||||
auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()});
|
||||
|
||||
// create output array
|
||||
NDArray result(newShape, true, getContext());
|
||||
|
@ -4228,7 +4229,7 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, cons
|
|||
// create dynamic array of extra parameters if array extraParams is empty (==nullptr)
|
||||
void* params = extraParams != nullptr ? const_cast<ExtraArguments*>(extraParams)->argumentsAsT(dataType()) : nullptr;
|
||||
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||
NativeOpExecutioner::execReduce3All(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||
|
@ -4260,7 +4261,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
|
|||
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
|
||||
NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
}
|
||||
synchronize("NDArray::reduceAlongDimension FloatOps");
|
||||
|
@ -4291,8 +4292,8 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
|
|||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||
}
|
||||
else { //if (!isEmpty()) {
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy);
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
||||
NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
}
|
||||
synchronize("NDArray::reduceAlongDimension SameOps");
|
||||
|
@ -4323,8 +4324,8 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
|
|||
NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy);
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
||||
NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
}
|
||||
synchronize("NDArray::reduceAlongDimension LongOps");
|
||||
|
@ -4355,8 +4356,8 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, cons
|
|||
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy);
|
||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
||||
NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
}
|
||||
synchronize("NDArray::reduceAlongDimension LongOps");
|
||||
|
@ -4524,7 +4525,7 @@ void NDArray::addRowVector(const NDArray& row, NDArray& target) const {
|
|||
|
||||
int dimension = 1;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &row});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4543,7 +4544,7 @@ void NDArray::subRowVector(const NDArray& row, NDArray& target) const {
|
|||
|
||||
int dimension = 1;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &row});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4563,7 +4564,7 @@ void NDArray::mulRowVector(const NDArray &row, NDArray &target) const {
|
|||
|
||||
int dimension = 1;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &row});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4584,7 +4585,7 @@ void NDArray::divRowVector(const NDArray &row, NDArray &target) const {
|
|||
|
||||
int dimension = 1;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &row});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4602,7 +4603,7 @@ void NDArray::addiRowVector(const NDArray& row) {
|
|||
|
||||
int dimension = 1;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({this}, {&row});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4620,7 +4621,7 @@ void NDArray::addColumnVector(const NDArray &column, NDArray &target) const {
|
|||
|
||||
int dimension = 0;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({&target}, {this, &column});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4637,7 +4638,7 @@ void NDArray::addiColumnVector(const NDArray &column) {
|
|||
|
||||
int dimension = 0;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({this}, {&column});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4654,7 +4655,7 @@ void NDArray::muliColumnVector(const NDArray& column) {
|
|||
|
||||
int dimension = 0;
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension);
|
||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
|
||||
|
||||
NDArray::prepareSpecialUse({this}, {&column});
|
||||
NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
|
||||
|
@ -4695,7 +4696,7 @@ ResultSet NDArray::multipleTensorsAlongDimension(const std::vector<int> &indices
|
|||
if (indices.size() == 0)
|
||||
return result;
|
||||
|
||||
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), const_cast<int*>(dimensions.data()), dimensions.size());
|
||||
auto pack = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), const_cast<int*>(dimensions.data()), dimensions.size());
|
||||
|
||||
auto tadLength = shape::length(pack.primaryShapeInfo());
|
||||
auto numTads = lengthOf() / tadLength;
|
||||
|
@ -4816,7 +4817,7 @@ ResultSet NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
|
|||
throw std::runtime_error("NDArray::allTensorsAlongDimension static function: all input dimensions must be smaller than rank of input array !");
|
||||
|
||||
|
||||
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
|
||||
auto pack = ConstantTadHelper::getInstance().tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
|
||||
auto numTads = pack.numberOfTads();
|
||||
|
||||
for (Nd4jLong idx = 0; idx < numTads; idx++ ) {
|
||||
|
@ -4929,11 +4930,11 @@ void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) {
|
|||
if (shapeInfo != nullptr) {
|
||||
|
||||
ShapeDescriptor descriptor(shapeInfo);
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor);
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor);
|
||||
|
||||
_shapeInfo = reinterpret_cast<Nd4jLong *>(shapeBuffer.primary());
|
||||
_shapeInfo = shapeBuffer.primary();
|
||||
#ifdef __CUDABLAS__
|
||||
_shapeInfoD = reinterpret_cast<Nd4jLong *>(shapeBuffer.special());
|
||||
_shapeInfoD = shapeBuffer.special();
|
||||
#endif
|
||||
|
||||
if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY)
|
||||
|
@ -4956,11 +4957,11 @@ void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype)
|
|||
|
||||
Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dtype, true, getContext()->getWorkspace());
|
||||
ShapeDescriptor descriptor(shapeInfoTemp);
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor);
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor);
|
||||
|
||||
_shapeInfo = reinterpret_cast<Nd4jLong *>(shapeBuffer.primary());
|
||||
_shapeInfo = shapeBuffer.primary();
|
||||
#ifdef __CUDABLAS__
|
||||
_shapeInfoD = reinterpret_cast<Nd4jLong *>(shapeBuffer.special());
|
||||
_shapeInfoD = shapeBuffer.special();
|
||||
#endif
|
||||
|
||||
if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY)
|
||||
|
@ -4979,11 +4980,11 @@ void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype)
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) {
|
||||
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(const_cast<ShapeDescriptor &>(descriptor));
|
||||
auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(const_cast<ShapeDescriptor &>(descriptor));
|
||||
|
||||
_shapeInfo = reinterpret_cast<Nd4jLong *>(shapeBuffer.primary());
|
||||
_shapeInfo = shapeBuffer.primary();
|
||||
#ifdef __CUDABLAS__
|
||||
_shapeInfoD = reinterpret_cast<Nd4jLong *>(shapeBuffer.special());
|
||||
_shapeInfoD = shapeBuffer.special();
|
||||
#endif
|
||||
|
||||
if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY)
|
||||
|
@ -4995,11 +4996,11 @@ void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) {
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::setShapeInfo(const ConstantDataBuffer& shapeBuffer) {
|
||||
void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) {
|
||||
|
||||
_shapeInfo = reinterpret_cast<Nd4jLong *>(const_cast<ConstantDataBuffer&>(shapeBuffer).primary());
|
||||
_shapeInfo = shapeBuffer.primary();
|
||||
#ifdef __CUDABLAS__
|
||||
_shapeInfoD = reinterpret_cast<Nd4jLong *>(const_cast<ConstantDataBuffer&>(shapeBuffer).special());
|
||||
_shapeInfoD = shapeBuffer.special();
|
||||
#endif
|
||||
|
||||
if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY)
|
||||
|
@ -5350,7 +5351,7 @@ NDArray operator+(T1&& arr1, T2&& arr2) {
|
|||
|
||||
if (arr1.isS() || arr2.isS())
|
||||
throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
|
||||
|
||||
PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)");
|
||||
|
@ -5400,7 +5401,7 @@ NDArray operator-(T1&& arr1, T2&& arr2) {
|
|||
|
||||
if (arr1.isS() || arr2.isS())
|
||||
throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
|
||||
|
||||
PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)");
|
||||
|
@ -5450,7 +5451,7 @@ NDArray operator*(T1&& arr1, T2&& arr2) {
|
|||
|
||||
if (arr1.isS() || arr2.isS())
|
||||
throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
|
||||
|
||||
PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)");
|
||||
|
@ -5500,7 +5501,7 @@ NDArray operator/(T1&& arr1, T2&& arr2) {
|
|||
|
||||
if (arr1.isS() || arr2.isS())
|
||||
throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!");
|
||||
if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
|
||||
throw sd::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
|
||||
|
||||
PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)");
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_POINTERDEALLOCATOR_H_
|
||||
#define SD_POINTERDEALLOCATOR_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
class ND4J_EXPORT PointerDeallocator {
|
||||
public:
|
||||
PointerDeallocator() = default;
|
||||
~PointerDeallocator() = default;
|
||||
|
||||
virtual void release(void* ptr);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif //SD_POINTERDEALLOCATOR_H_
|
|
@ -0,0 +1,49 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_ARRAY_POINTER_H_
|
||||
#define SD_ARRAY_POINTER_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <array/PointerDeallocator.h>
|
||||
#include <memory>
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT PointerWrapper {
|
||||
private:
|
||||
void* _pointer = nullptr;
|
||||
std::shared_ptr<PointerDeallocator> _deallocator;
|
||||
|
||||
public:
|
||||
PointerWrapper(void* ptr, const std::shared_ptr<PointerDeallocator> &deallocator = {});
|
||||
PointerWrapper() = default;
|
||||
~PointerWrapper();
|
||||
|
||||
void* pointer() const;
|
||||
|
||||
template <typename T>
|
||||
T* pointerAsT() const {
|
||||
return reinterpret_cast<T*>(pointer());
|
||||
}
|
||||
};
|
||||
} // namespace sd
|
||||
|
||||
#endif //SD_ARRAY_POINTER_H_
|
|
@ -0,0 +1,38 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_PRIMARYPOINTERDEALLOCATOR_H_
|
||||
#define SD_PRIMARYPOINTERDEALLOCATOR_H_
|
||||
|
||||
#include <system/dll.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <array/PointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT PrimaryPointerDeallocator : public PointerDeallocator {
|
||||
public:
|
||||
PrimaryPointerDeallocator() = default;
|
||||
~PrimaryPointerDeallocator() = default;
|
||||
|
||||
void release(void* ptr) override;
|
||||
};
|
||||
}
|
||||
|
||||
#endif //SD_PRIMARYPOINTERDEALLOCATOR_H_
|
|
@ -21,17 +21,18 @@
|
|||
#ifndef DEV_TESTS_TADPACK_H
|
||||
#define DEV_TESTS_TADPACK_H
|
||||
|
||||
#include "ConstantDataBuffer.h"
|
||||
#include <array/ConstantOffsetsBuffer.h>
|
||||
#include <array/ConstantShapeBuffer.h>
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT TadPack {
|
||||
private:
|
||||
ConstantDataBuffer _tadShape;
|
||||
ConstantDataBuffer _tadOffsets;
|
||||
ConstantShapeBuffer _tadShape;
|
||||
ConstantOffsetsBuffer _tadOffsets;
|
||||
Nd4jLong _numTads = 0 ;
|
||||
int _shapeInfoLength = 0;
|
||||
public:
|
||||
explicit TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads);
|
||||
explicit TadPack(const ConstantShapeBuffer &shapes, const ConstantOffsetsBuffer &offets, Nd4jLong numTads);
|
||||
TadPack() = default;
|
||||
~TadPack() = default;
|
||||
|
||||
|
|
|
@ -338,7 +338,7 @@ void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
|
|||
const int ews = target.ews();
|
||||
const auto targetLen = target.lengthOf();
|
||||
if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here
|
||||
//#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)
|
||||
//#pragma omp parallel for simd if(targetLen > Environment::getInstance().elementwiseThreshold()) schedule(guided)
|
||||
for(Nd4jLong i=0; i<targetLen; ++i) {
|
||||
auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo());
|
||||
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.buffer(), i, buffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/CudaPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
void CudaPointerDeallocator::release(void *ptr) {
|
||||
cudaFree(ptr);
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -70,16 +70,16 @@ void DataBuffer::allocateSpecial() {
|
|||
auto deviceId = sd::AffinityManager::currentDeviceId();
|
||||
|
||||
if (_workspace == nullptr)
|
||||
if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds device limits", sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes());
|
||||
if (!sd::memory::MemoryCounter::getInstance().validate(getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds device limits", sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), getLenInBytes());
|
||||
|
||||
|
||||
ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t);
|
||||
_isOwnerSpecial = true;
|
||||
|
||||
if (_workspace == nullptr) {
|
||||
sd::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance()->countIn(sd::memory::MemoryType::DEVICE, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countIn(deviceId, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countIn(sd::memory::MemoryType::DEVICE, getLenInBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,8 +135,8 @@ void DataBuffer::deleteSpecial() {
|
|||
|
||||
// count out towards DataBuffer device, only if we're not in workspace
|
||||
if (_workspace == nullptr) {
|
||||
sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance()->countOut(sd::memory::MemoryType::DEVICE, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countOut(_deviceId, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countOut(sd::memory::MemoryType::DEVICE, getLenInBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ void* NDArray::platformBuffer() { return specialBuffer(); }
|
|||
void const* NDArray::platformBuffer() const { return specialBuffer(); }
|
||||
|
||||
Nd4jLong const* NDArray::platformShapeInfo() const { return specialShapeInfo(); }
|
||||
//Nd4jLong const* NDArray::platformShapeInfo() { return specialShapeInfo(); }
|
||||
//Nd4jLong const* NDArray::platform() { return special(); }
|
||||
|
||||
void NDArray::syncToDevice() const {
|
||||
auto currentDeviceId = AffinityManager::currentDeviceId();
|
||||
|
|
|
@ -18,29 +18,38 @@
|
|||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "../ConstantDataBuffer.h"
|
||||
#include <array/ConstantDataBuffer.h>
|
||||
#include <array/DataTypeUtils.h>
|
||||
|
||||
namespace sd {
|
||||
ConstantDataBuffer::ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, Nd4jLong numEelements, Nd4jLong sizeOf) {
|
||||
_primaryBuffer = primary;
|
||||
_specialBuffer = special;
|
||||
_length = numEelements;
|
||||
_sizeOf = sizeOf;
|
||||
ConstantDataBuffer::ConstantDataBuffer(
|
||||
const std::shared_ptr<PointerWrapper>& primary,
|
||||
uint64_t numEelements,
|
||||
DataType dtype) : ConstantDataBuffer(primary, {}, numEelements, dtype) {
|
||||
//
|
||||
}
|
||||
|
||||
ConstantDataBuffer::ConstantDataBuffer(
|
||||
const std::shared_ptr<PointerWrapper>& primary,
|
||||
const std::shared_ptr<PointerWrapper>& special,
|
||||
uint64_t numEelements,
|
||||
DataType dtype) : _primaryBuffer(primary), _specialBuffer(special), _length(numEelements) {
|
||||
_sizeOf = DataTypeUtils::sizeOf(dtype);
|
||||
}
|
||||
|
||||
Nd4jPointer ConstantDataBuffer::primary() const {
|
||||
return _primaryBuffer;
|
||||
void* ConstantDataBuffer::primary() const {
|
||||
return _primaryBuffer->pointer();
|
||||
}
|
||||
|
||||
Nd4jPointer ConstantDataBuffer::special() const {
|
||||
return _specialBuffer;
|
||||
void* ConstantDataBuffer::special() const {
|
||||
return _specialBuffer ? _specialBuffer->pointer() : nullptr;
|
||||
}
|
||||
|
||||
Nd4jLong ConstantDataBuffer::sizeOf() const {
|
||||
uint8_t ConstantDataBuffer::sizeOf() const {
|
||||
return _sizeOf;
|
||||
}
|
||||
|
||||
Nd4jLong ConstantDataBuffer::length() const {
|
||||
uint64_t ConstantDataBuffer::length() const {
|
||||
return _length;
|
||||
}
|
||||
|
||||
|
@ -52,21 +61,21 @@ namespace sd {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
T* ConstantDataBuffer::primaryAsT() {
|
||||
return reinterpret_cast<T*>(_primaryBuffer);
|
||||
T* ConstantDataBuffer::primaryAsT() const {
|
||||
return reinterpret_cast<T*>(_primaryBuffer->pointer());
|
||||
}
|
||||
template ND4J_EXPORT float* ConstantDataBuffer::primaryAsT<float>();
|
||||
template ND4J_EXPORT double* ConstantDataBuffer::primaryAsT<double>();
|
||||
template ND4J_EXPORT int* ConstantDataBuffer::primaryAsT<int>();
|
||||
template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT<Nd4jLong>();
|
||||
template ND4J_EXPORT float* ConstantDataBuffer::primaryAsT<float>() const;
|
||||
template ND4J_EXPORT double* ConstantDataBuffer::primaryAsT<double>() const;
|
||||
template ND4J_EXPORT int* ConstantDataBuffer::primaryAsT<int>() const;
|
||||
template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT<Nd4jLong>() const;
|
||||
|
||||
template <typename T>
|
||||
T* ConstantDataBuffer::specialAsT() {
|
||||
return reinterpret_cast<T*>(_specialBuffer);
|
||||
T* ConstantDataBuffer::specialAsT() const {
|
||||
return reinterpret_cast<T*>(special());
|
||||
}
|
||||
template ND4J_EXPORT float* ConstantDataBuffer::specialAsT<float>();
|
||||
template ND4J_EXPORT double* ConstantDataBuffer::specialAsT<double>();
|
||||
template ND4J_EXPORT int* ConstantDataBuffer::specialAsT<int>();
|
||||
template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT<Nd4jLong>();
|
||||
template ND4J_EXPORT float* ConstantDataBuffer::specialAsT<float>() const;
|
||||
template ND4J_EXPORT double* ConstantDataBuffer::specialAsT<double>() const;
|
||||
template ND4J_EXPORT int* ConstantDataBuffer::specialAsT<int>() const;
|
||||
template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT<Nd4jLong>() const;
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/ConstantOffsetsBuffer.h>
|
||||
|
||||
namespace sd {
|
||||
ConstantOffsetsBuffer::ConstantOffsetsBuffer(const std::shared_ptr<PointerWrapper> &primary) :
|
||||
ConstantOffsetsBuffer(primary, std::shared_ptr<PointerWrapper>(nullptr)) {
|
||||
//
|
||||
}
|
||||
|
||||
ConstantOffsetsBuffer::ConstantOffsetsBuffer(const std::shared_ptr<PointerWrapper> &primary,
|
||||
const std::shared_ptr<PointerWrapper> &special) {
|
||||
_primaryOffsets = primary;
|
||||
_specialOffsets = special;
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantOffsetsBuffer::primary() const {
|
||||
return reinterpret_cast<Nd4jLong*>(_primaryOffsets->pointer());
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantOffsetsBuffer::special() const {
|
||||
return _specialOffsets ? reinterpret_cast<Nd4jLong*>(_specialOffsets->pointer()) : nullptr;
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantOffsetsBuffer::platform() const {
|
||||
#ifdef __CUDABLAS__
|
||||
return special();
|
||||
#else
|
||||
return primary();
|
||||
#endif // CUDABLAS
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -0,0 +1,51 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/ConstantShapeBuffer.h>
|
||||
|
||||
namespace sd {
|
||||
ConstantShapeBuffer::ConstantShapeBuffer(const std::shared_ptr<PointerWrapper> &primary) :
|
||||
ConstantShapeBuffer(primary, std::shared_ptr<PointerWrapper>(nullptr)) {
|
||||
//
|
||||
}
|
||||
|
||||
ConstantShapeBuffer::ConstantShapeBuffer(const std::shared_ptr<PointerWrapper> &primary,
|
||||
const std::shared_ptr<PointerWrapper> &special) {
|
||||
_primaryShapeInfo = primary;
|
||||
_specialShapeInfo = special;
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantShapeBuffer::primary() const {
|
||||
return reinterpret_cast<Nd4jLong*>(_primaryShapeInfo->pointer());
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantShapeBuffer::special() const {
|
||||
return _specialShapeInfo ? reinterpret_cast<Nd4jLong*>(_specialShapeInfo->pointer()) : nullptr;
|
||||
}
|
||||
|
||||
const Nd4jLong *ConstantShapeBuffer::platform() const {
|
||||
#ifdef __CUDABLAS__
|
||||
return special();
|
||||
#else
|
||||
return primary();
|
||||
#endif // CUDABLAS
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -237,14 +237,14 @@ namespace sd {
|
|||
auto deviceId = sd::AffinityManager::currentDeviceId();
|
||||
// check if this allocation won't bring us above limit
|
||||
if (_workspace == nullptr) {
|
||||
if (Environment::getInstance()->isCPU()) {
|
||||
if (Environment::getInstance().isCPU()) {
|
||||
// on cpu backend we validate against device 0 for now
|
||||
if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds HOST device limits", sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes());
|
||||
if (!sd::memory::MemoryCounter::getInstance().validate(getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds HOST device limits", sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId), getLenInBytes());
|
||||
} else {
|
||||
// in heterogenous mode we valdate against device group
|
||||
if (!sd::memory::MemoryCounter::getInstance()->validateGroup(sd::memory::MemoryType::HOST, getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds HOST group limits", sd::memory::MemoryCounter::getInstance()->groupLimit(sd::memory::MemoryType::HOST), getLenInBytes());
|
||||
if (!sd::memory::MemoryCounter::getInstance().validateGroup(sd::memory::MemoryType::HOST, getLenInBytes()))
|
||||
throw sd::allocation_exception::build("Requested amount exceeds HOST group limits", sd::memory::MemoryCounter::getInstance().groupLimit(sd::memory::MemoryType::HOST), getLenInBytes());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,10 +253,10 @@ namespace sd {
|
|||
|
||||
// count in towards current deviceId if we're not in workspace mode
|
||||
if (_workspace == nullptr) {
|
||||
if (Environment::getInstance()->isCPU()) // we don't want this counter to be added to CUDA device
|
||||
sd::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes());
|
||||
if (Environment::getInstance().isCPU()) // we don't want this counter to be added to CUDA device
|
||||
sd::memory::MemoryCounter::getInstance().countIn(deviceId, getLenInBytes());
|
||||
|
||||
sd::memory::MemoryCounter::getInstance()->countIn(sd::memory::MemoryType::HOST, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countIn(sd::memory::MemoryType::HOST, getLenInBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -279,10 +279,10 @@ namespace sd {
|
|||
|
||||
// count out towards DataBuffer device, only if we're not in workspace
|
||||
if (_workspace == nullptr) {
|
||||
if (Environment::getInstance()->isCPU())
|
||||
sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes());
|
||||
if (Environment::getInstance().isCPU())
|
||||
sd::memory::MemoryCounter::getInstance().countOut(_deviceId, getLenInBytes());
|
||||
|
||||
sd::memory::MemoryCounter::getInstance()->countOut(sd::memory::MemoryType::HOST, getLenInBytes());
|
||||
sd::memory::MemoryCounter::getInstance().countOut(sd::memory::MemoryType::HOST, getLenInBytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/PointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
void PointerDeallocator::release(void *ptr) {
|
||||
// noop
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -0,0 +1,37 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/PointerWrapper.h>
|
||||
|
||||
namespace sd {
|
||||
PointerWrapper::PointerWrapper(void *ptr, const std::shared_ptr<PointerDeallocator> &deallocator): _pointer(ptr), _deallocator(deallocator) {
|
||||
//
|
||||
}
|
||||
|
||||
PointerWrapper::~PointerWrapper() {
|
||||
if (_deallocator.get() != nullptr)
|
||||
_deallocator->release(_pointer);
|
||||
}
|
||||
|
||||
void *PointerWrapper::pointer() const {
|
||||
return _pointer;
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
void PrimaryPointerDeallocator::release(void *ptr) {
|
||||
delete[] reinterpret_cast<int8_t*>(ptr);
|
||||
}
|
||||
|
||||
} // namespace sd
|
|
@ -23,26 +23,24 @@
|
|||
#include <helpers/shape.h>
|
||||
|
||||
namespace sd {
|
||||
TadPack::TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads) {
|
||||
_tadShape = shapes;
|
||||
_tadOffsets = offets;
|
||||
TadPack::TadPack(const ConstantShapeBuffer &shapes, const ConstantOffsetsBuffer &offets, Nd4jLong numTads) : _tadShape(shapes), _tadOffsets(offets) {
|
||||
_numTads = numTads;
|
||||
}
|
||||
|
||||
const Nd4jLong* TadPack::primaryShapeInfo() const {
|
||||
return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
|
||||
return _tadShape.primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* TadPack::primaryOffsets() const {
|
||||
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
|
||||
return _tadOffsets.primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* TadPack::specialShapeInfo() const {
|
||||
return reinterpret_cast<Nd4jLong *>(_tadShape.special());
|
||||
return _tadShape.special();
|
||||
}
|
||||
|
||||
const Nd4jLong* TadPack::specialOffsets() const {
|
||||
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
|
||||
return _tadOffsets.special();
|
||||
}
|
||||
|
||||
Nd4jLong TadPack::numberOfTads() const {
|
||||
|
@ -50,11 +48,11 @@ namespace sd {
|
|||
}
|
||||
|
||||
const Nd4jLong* TadPack::platformShapeInfo() const {
|
||||
return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
|
||||
return sd::Environment::getInstance().isCPU() ? primaryShapeInfo() : specialShapeInfo();
|
||||
}
|
||||
|
||||
const Nd4jLong* TadPack::platformOffsets() const {
|
||||
return sd::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
|
||||
return sd::Environment::getInstance().isCPU() ? primaryOffsets() : specialOffsets();
|
||||
}
|
||||
|
||||
int TadPack::shapeInfoLength() const {
|
||||
|
|
|
@ -35,9 +35,7 @@
|
|||
namespace samediff {
|
||||
class ND4J_EXPORT ThreadPool {
|
||||
private:
|
||||
static ThreadPool* _INSTANCE;
|
||||
|
||||
std::vector<std::thread*> _threads;
|
||||
std::vector<std::thread> _threads;
|
||||
std::vector<BlockingQueue<CallableWithArguments*>*> _queues;
|
||||
std::vector<CallableInterface*> _interfaces;
|
||||
|
||||
|
@ -48,7 +46,7 @@ namespace samediff {
|
|||
ThreadPool();
|
||||
~ThreadPool();
|
||||
public:
|
||||
static ThreadPool* getInstance();
|
||||
static ThreadPool& getInstance();
|
||||
|
||||
/**
|
||||
* This method returns list of pointers to threads ONLY if num_threads of threads were available upon request, returning empty list otherwise
|
||||
|
|
|
@ -107,7 +107,7 @@ namespace samediff {
|
|||
* @param increment
|
||||
* @return
|
||||
*/
|
||||
static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
/**
|
||||
* This function executes 1 dimensional loop for a given number of threads
|
||||
|
@ -119,7 +119,7 @@ namespace samediff {
|
|||
* @param numThreads
|
||||
* @return
|
||||
*/
|
||||
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
/**
|
||||
* This method will execute function splitting 2 nested loops space with multiple threads
|
||||
|
@ -134,7 +134,7 @@ namespace samediff {
|
|||
* @param inc_y
|
||||
* @return
|
||||
*/
|
||||
static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads(), bool debug = false);
|
||||
static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads(), bool debug = false);
|
||||
|
||||
/**
|
||||
* This method will execute function splitting 3 nested loops space with multiple threads
|
||||
|
@ -152,7 +152,7 @@ namespace samediff {
|
|||
* @param inc_z
|
||||
* @return
|
||||
*/
|
||||
static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -160,18 +160,18 @@ namespace samediff {
|
|||
* @param numThreads
|
||||
* @return
|
||||
*/
|
||||
static int parallel_do(FUNC_DO function, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int parallel_do(FUNC_DO function, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
/**
|
||||
* This method will execute function in parallel preserving the parts to be aligned increment size
|
||||
* PLEASE NOTE: this function can use smaller number of threads than requested.
|
||||
*
|
||||
*/
|
||||
static int parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size = sizeof(float), uint32_t req_numThreads = sd::Environment::getInstance()->maxMasterThreads());
|
||||
static int parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size = sizeof(float), uint32_t req_numThreads = sd::Environment::getInstance().maxMasterThreads());
|
||||
|
||||
};
|
||||
}
|
||||
|
|
|
@ -61,14 +61,19 @@ namespace sd {
|
|||
|
||||
}
|
||||
|
||||
LaunchContext* LaunchContext::defaultContext() {
|
||||
// TODO: we need it to be device-aware, but only once we add NUMA support for cpu
|
||||
if (LaunchContext::_contexts.empty()) {
|
||||
LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
|
||||
}
|
||||
static std::mutex _lock;
|
||||
|
||||
// return context for current device
|
||||
return LaunchContext::_contexts[0].get();
|
||||
LaunchContext* LaunchContext::defaultContext() {
|
||||
{
|
||||
// synchronous block goes here
|
||||
std::lock_guard<std::mutex> lock(_lock);
|
||||
// TODO: we need it to be device-aware, but only once we add NUMA support for cpu
|
||||
if (LaunchContext::_contexts.empty())
|
||||
LaunchContext::_contexts.emplace_back(std::make_shared<LaunchContext>());
|
||||
}
|
||||
|
||||
// return context for current device
|
||||
return LaunchContext::_contexts[0].get();
|
||||
}
|
||||
|
||||
std::mutex* LaunchContext::deviceMutex() {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -15,7 +16,7 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver119 on 30.11.17.
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <execution/LaunchContext.h>
|
||||
|
@ -75,36 +76,37 @@ LaunchContext::LaunchContext() {
|
|||
}
|
||||
|
||||
LaunchContext* LaunchContext::defaultContext() {
|
||||
/**
|
||||
* This method returns LaunchContext, that has multiple entities within:
|
||||
* 1) temporary buffers. they must be per-thread
|
||||
* 2) CUDA stream. it must be either per-thread or per-device
|
||||
* 3) cuBLAS handle. it must be per-device
|
||||
*/
|
||||
auto deviceId = AffinityManager::currentDeviceId();
|
||||
/**
|
||||
* This method returns LaunchContext, that has multiple entities within:
|
||||
* 1) temporary buffers. they must be per-thread
|
||||
* 2) CUDA stream. it must be either per-thread or per-device
|
||||
* 3) cuBLAS handle. it must be per-device
|
||||
*/
|
||||
auto deviceId = AffinityManager::currentDeviceId();
|
||||
|
||||
{
|
||||
// we need this block synchronous, to avoid double initialization etc
|
||||
_mutex.lock();
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
if (LaunchContext::_contexts.empty()) {
|
||||
// create one context per device
|
||||
auto numDevices = AffinityManager::numberOfDevices();
|
||||
// create one context per device
|
||||
auto numDevices = AffinityManager::numberOfDevices();
|
||||
|
||||
_contexts.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
_deviceMutexes[e] = new std::mutex();
|
||||
_contexts.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
_deviceMutexes[e] = new std::mutex();
|
||||
|
||||
AffinityManager::setCurrentNativeDevice(e);
|
||||
AffinityManager::setCurrentNativeDevice(e);
|
||||
|
||||
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
||||
}
|
||||
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
||||
}
|
||||
|
||||
// don't forget to restore device back again
|
||||
AffinityManager::setCurrentNativeDevice(deviceId);
|
||||
// don't forget to restore device back again
|
||||
AffinityManager::setCurrentNativeDevice(deviceId);
|
||||
}
|
||||
_mutex.unlock();
|
||||
}
|
||||
|
||||
// return context for current device
|
||||
return LaunchContext::_contexts[deviceId].get();
|
||||
// return context for current device
|
||||
return LaunchContext::_contexts[deviceId].get();
|
||||
}
|
||||
|
||||
|
||||
|
@ -121,11 +123,11 @@ LaunchContext::LaunchContext() {
|
|||
};
|
||||
|
||||
void* LaunchContext::getCublasHandle() const {
|
||||
return CublasHelper::getInstance()->handle();
|
||||
return CublasHelper::getInstance().handle();
|
||||
};
|
||||
|
||||
void* LaunchContext::getCusolverHandle() const {
|
||||
return CublasHelper::getInstance()->solver();
|
||||
return CublasHelper::getInstance().solver();
|
||||
};
|
||||
|
||||
cudaStream_t* LaunchContext::getCudaStream() const {
|
||||
|
@ -175,7 +177,7 @@ LaunchContext::LaunchContext() {
|
|||
}
|
||||
|
||||
void* LaunchContext::getCuDnnHandle() const {
|
||||
return CublasHelper::getInstance()->cudnn();
|
||||
return CublasHelper::getInstance().cudnn();
|
||||
}
|
||||
|
||||
sd::ErrorReference* LaunchContext::errorReference() {
|
||||
|
|
|
@ -78,7 +78,7 @@ namespace samediff {
|
|||
ThreadPool::ThreadPool() {
|
||||
// TODO: number of threads must reflect number of cores for UMA system. In case of NUMA it should be per-device pool
|
||||
// FIXME: on mobile phones this feature must NOT be used
|
||||
_available = sd::Environment::getInstance()->maxThreads();
|
||||
_available = sd::Environment::getInstance().maxThreads();
|
||||
|
||||
_queues.resize(_available.load());
|
||||
_threads.resize(_available.load());
|
||||
|
@ -88,7 +88,7 @@ namespace samediff {
|
|||
for (int e = 0; e < _available.load(); e++) {
|
||||
_queues[e] = new BlockingQueue<CallableWithArguments*>(2);
|
||||
_interfaces[e] = new CallableInterface();
|
||||
_threads[e] = new std::thread(executionLoopWithInterface_, e, _interfaces[e]);
|
||||
_threads[e] = std::thread(executionLoopWithInterface_, e, _interfaces[e]);
|
||||
_tickets.push(new Ticket());
|
||||
// _threads[e] = new std::thread(executionLoop_, e, _queues[e]);
|
||||
|
||||
|
@ -125,19 +125,22 @@ namespace samediff {
|
|||
// stop each and every thread
|
||||
|
||||
// release queue and thread
|
||||
//delete _queues[e];
|
||||
//delete _threads[e];
|
||||
delete _queues[e];
|
||||
_threads[e].detach();
|
||||
//delete _interfaces[e];
|
||||
}
|
||||
|
||||
while (!_tickets.empty()) {
|
||||
auto t = _tickets.front();
|
||||
_tickets.pop();
|
||||
delete t;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static std::mutex _lmutex;
|
||||
|
||||
ThreadPool* ThreadPool::getInstance() {
|
||||
std::unique_lock<std::mutex> lock(_lmutex);
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new ThreadPool();
|
||||
|
||||
return _INSTANCE;
|
||||
ThreadPool& ThreadPool::getInstance() {
|
||||
static ThreadPool instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void ThreadPool::release(int numThreads) {
|
||||
|
@ -188,7 +191,4 @@ namespace samediff {
|
|||
std::unique_lock<std::mutex> lock(_lock);
|
||||
_tickets.push(ticket);
|
||||
}
|
||||
|
||||
|
||||
ThreadPool* ThreadPool::_INSTANCE = 0;
|
||||
}
|
||||
|
|
|
@ -357,7 +357,7 @@ namespace samediff {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads);
|
||||
if (ticket != nullptr) {
|
||||
// if we got our threads - we'll run our jobs here
|
||||
auto span = delta / numThreads;
|
||||
|
@ -449,7 +449,7 @@ namespace samediff {
|
|||
// but we still mimic multithreaded execution
|
||||
return numThreads;
|
||||
} else {
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads);
|
||||
if (ticket != nullptr) {
|
||||
|
||||
for (int e = 0; e < numThreads; e++) {
|
||||
|
@ -499,7 +499,7 @@ namespace samediff {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads);
|
||||
if (ticket != nullptr) {
|
||||
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
|
||||
|
||||
|
@ -526,7 +526,7 @@ namespace samediff {
|
|||
}
|
||||
|
||||
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1);
|
||||
if (ticket != nullptr) {
|
||||
|
||||
// submit tasks one by one
|
||||
|
@ -565,7 +565,7 @@ namespace samediff {
|
|||
if (numThreads == 1)
|
||||
return function(0, start, stop, increment);
|
||||
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1);
|
||||
if (ticket == nullptr)
|
||||
return function(0, start, stop, increment);
|
||||
|
||||
|
@ -609,7 +609,7 @@ namespace samediff {
|
|||
if (numThreads == 1)
|
||||
return function(0, start, stop, increment);
|
||||
|
||||
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||
auto ticket = ThreadPool::getInstance().tryAcquire(numThreads - 1);
|
||||
if (ticket == nullptr)
|
||||
return function(0, start, stop, increment);
|
||||
|
||||
|
@ -668,7 +668,7 @@ namespace samediff {
|
|||
numThreads = static_cast<int>(std::ceil((double)delta / spand));
|
||||
auto span = static_cast<Nd4jLong>(spand);
|
||||
|
||||
auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||
auto ticket = samediff::ThreadPool::getInstance().tryAcquire(numThreads);
|
||||
if (ticket != nullptr) {
|
||||
//tail_add is additional value of the last part
|
||||
//it could be negative or positive
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace samediff {
|
|||
|
||||
Ticket::Ticket() {
|
||||
_acquired = true;
|
||||
_interfaces.resize(sd::Environment::getInstance()->maxThreads());
|
||||
_interfaces.resize(sd::Environment::getInstance().maxThreads());
|
||||
}
|
||||
|
||||
bool Ticket::acquired() {
|
||||
|
@ -80,11 +80,11 @@ namespace samediff {
|
|||
_interfaces[e]->markAvailable();
|
||||
|
||||
// increment availability counter
|
||||
ThreadPool::getInstance()->release();
|
||||
ThreadPool::getInstance().release();
|
||||
}
|
||||
|
||||
// return this ticket back to the pool
|
||||
ThreadPool::getInstance()->release(this);
|
||||
ThreadPool::getInstance().release(this);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ namespace sd {
|
|||
std::vector<sd::DataType> _dataTypes;
|
||||
|
||||
sd::ops::OpDescriptor* _opDescriptor;
|
||||
bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN();
|
||||
bool _useMKLDNN = sd::Environment::getInstance().isUseMKLDNN();
|
||||
|
||||
// target engine for execution
|
||||
samediff::Engine _engine = DEFAULT_ENGINE;
|
||||
|
|
|
@ -30,7 +30,6 @@ namespace sd {
|
|||
namespace graph {
|
||||
class ND4J_EXPORT GraphHolder {
|
||||
private:
|
||||
static GraphHolder *_INSTANCE;
|
||||
MAP_IMPL<Nd4jLong, Graph *> _graphF;
|
||||
|
||||
MAP_IMPL<Nd4jLong, SimpleReadWriteLock> _locks;
|
||||
|
@ -38,7 +37,7 @@ namespace sd {
|
|||
GraphHolder() = default;
|
||||
~GraphHolder() = default;
|
||||
public:
|
||||
static GraphHolder* getInstance();
|
||||
static GraphHolder& getInstance();
|
||||
|
||||
void registerGraph(Nd4jLong graphId, Graph *graph);
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ namespace sd {
|
|||
// FIXME!!
|
||||
outputAddr.second = e;
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose())
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second);
|
||||
|
||||
auto varIn = __variableSpace->getVariable(inputAddr);
|
||||
|
@ -45,7 +45,7 @@ namespace sd {
|
|||
// FIXME: this is obviously wrong, we should keep depth track for backprop here
|
||||
varOut->getNDArray()->assign(varIn->getNDArray());
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose())
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
nd4j_debug("In after: [%f]; Out after: [%f]\n", varIn->getNDArray()->meanNumber().e<float>(0), varOut->getNDArray()->meanNumber().e<float>(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ namespace sd {
|
|||
// now we should take result of the Scope run, and evaluate it
|
||||
auto result = __variableSpace->getVariable(lastNode)->getNDArray();
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose())
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
result->printBuffer("Result of the last node:");
|
||||
|
||||
// if result evaluates to 0.0 - condition returned FALSE
|
||||
|
|
|
@ -236,7 +236,7 @@ namespace sd {
|
|||
|
||||
auto v = variable(p);
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose() && v != nullptr && v->getNDArray() != nullptr) {
|
||||
if (Environment::getInstance().isDebugAndVerbose() && v != nullptr && v->getNDArray() != nullptr) {
|
||||
auto array = v->getNDArray();
|
||||
std::string shape_ = ShapeUtils::shapeAsString(array);
|
||||
auto type = DataTypeUtils::asString(array->dataType());
|
||||
|
|
|
@ -166,7 +166,7 @@ namespace sd {
|
|||
// aNewShape[5] = 8192; // set type as FLOAT32 by default
|
||||
// aNewShape[6] = 1;
|
||||
// aNewShape[7] = 99;
|
||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {1,1});
|
||||
newShape = ConstantShapeHelper::getInstance().createShapeInfo(DataType::FLOAT32, 'c', {1,1});
|
||||
} else {
|
||||
auto in = node->input()->at(0);
|
||||
|
||||
|
@ -184,7 +184,7 @@ namespace sd {
|
|||
//shape::TAD tad(oldShape, node->getDimensions()->data(), node->getDimensions()->size());
|
||||
auto numTads = shape::tadLength(oldShape, node->getDimensions()->data(), node->getDimensions()->size());
|
||||
Nd4jLong shape[2] = {1, (int) numTads};
|
||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(oldShape), 'c', 2, shape);
|
||||
newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(oldShape), 'c', 2, shape);
|
||||
}
|
||||
|
||||
std::pair<int, int> pairAddr(node->id(), 0);
|
||||
|
@ -805,7 +805,7 @@ namespace sd {
|
|||
// we're adding final nodes of the graph. those, not used as input anywhere
|
||||
nd4j_debug("Paring nodes... \n", "");
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
if (Environment::getInstance().isDebugAndVerbose()) {
|
||||
// nd4j_printv("current _output", _output);
|
||||
}
|
||||
//_output.clear();
|
||||
|
@ -852,7 +852,7 @@ namespace sd {
|
|||
|
||||
if (std::find(_output.begin(), _output.end(), node->id()) == _output.end())
|
||||
_output.emplace_back(node->id());
|
||||
} else if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
} else if (Environment::getInstance().isDebugAndVerbose()) {
|
||||
nd4j_debug("Node [%i:<%s>] has %i outputs announced:\n", v, node->name()->c_str(), node->output()->size());
|
||||
printf("{");
|
||||
for (auto s : *node->output()) {
|
||||
|
@ -1202,7 +1202,7 @@ namespace sd {
|
|||
}
|
||||
break;
|
||||
default: {
|
||||
opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance()->local_to_string<int>((int) node->opNum()) + "}";
|
||||
opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance().local_to_string<int>((int) node->opNum()) + "}";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1250,7 +1250,7 @@ namespace sd {
|
|||
}
|
||||
break;
|
||||
default: {
|
||||
opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance()->local_to_string<int>((int) node->opNum()) + "}";
|
||||
opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance().local_to_string<int>((int) node->opNum()) + "}";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1447,7 +1447,7 @@ namespace sd {
|
|||
}
|
||||
|
||||
|
||||
hash = ops::HashHelper::getInstance()->getLongHash(localStamp);
|
||||
hash = ops::HashHelper::getInstance().getLongHash(localStamp);
|
||||
|
||||
nd4j_debug("Graph hash: %lld\n", hash);
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ namespace graph {
|
|||
|
||||
Context context(node->getContextPrototype(), variableSpace);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose()) {
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose()) {
|
||||
//nd4j_debug("Input variables: %i\n", node->input()->size());
|
||||
printf(" Inputs: {");
|
||||
for (int e = 0; e < node->input()->size(); e++) {
|
||||
|
@ -215,10 +215,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
}
|
||||
auto flowPath = __variableSpace->flowPath();
|
||||
|
||||
Nd4jLong tb0 = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L;
|
||||
Nd4jLong tb0 = Environment::getInstance().isProfiling() ? GraphProfile::currentTime() : 0L;
|
||||
graph->buildGraph();
|
||||
|
||||
auto footprintForward = sd::memory::MemoryRegistrator::getInstance()->getGraphMemoryFootprint(graph->hashCode());
|
||||
auto footprintForward = sd::memory::MemoryRegistrator::getInstance().getGraphMemoryFootprint(graph->hashCode());
|
||||
if (footprintForward > 0) {
|
||||
if (__variableSpace->launchContext()->getWorkspace() != nullptr) {
|
||||
// this method will work only if current workspace size is smaller then proposed value
|
||||
|
@ -228,10 +228,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
}
|
||||
|
||||
// optionally saving graph build time
|
||||
if (Environment::getInstance()->isProfiling())
|
||||
if (Environment::getInstance().isProfiling())
|
||||
flowPath->profile()->setBuildTime(GraphProfile::relativeTime(tb0));
|
||||
|
||||
Nd4jLong timeStart = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L;
|
||||
Nd4jLong timeStart = Environment::getInstance().isProfiling() ? GraphProfile::currentTime() : 0L;
|
||||
|
||||
bool pe = graph->getExecutorConfiguration()->_executionMode == ExecutionMode_AUTO;
|
||||
|
||||
|
@ -259,10 +259,10 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
|
||||
Node* node = graph->getOnion()->at(l)->at(n);
|
||||
|
||||
if (Environment::getInstance()->isProfiling())
|
||||
if (Environment::getInstance().isProfiling())
|
||||
flowPath->profile()->nodeById(node->id(), node->name()->c_str());
|
||||
|
||||
if (lastId != node->id() && Environment::getInstance()->isProfiling()) {
|
||||
if (lastId != node->id() && Environment::getInstance().isProfiling()) {
|
||||
if (lastId != -10000000)
|
||||
flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime));
|
||||
|
||||
|
@ -458,7 +458,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
// now we skip all branches except of this active one
|
||||
}
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose()) {
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose()) {
|
||||
|
||||
if (__variableSpace->getVariable(node->id())->hasNDArray()) {
|
||||
auto array = __variableSpace->getVariable(node->id())->getNDArray();
|
||||
|
@ -481,7 +481,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
}
|
||||
|
||||
// optionally saving execution time
|
||||
if (Environment::getInstance()->isProfiling()) {
|
||||
if (Environment::getInstance().isProfiling()) {
|
||||
flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime));
|
||||
flowPath->profile()->setExecutionTime(GraphProfile::relativeTime(timeStart));
|
||||
//flowPath->profile().printOut();
|
||||
|
@ -491,7 +491,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
if (__variableSpace->launchContext()->getWorkspace() != nullptr) {
|
||||
auto m = __variableSpace->launchContext()->getWorkspace()->getAllocatedSize();
|
||||
auto h = graph->hashCode();
|
||||
sd::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m);
|
||||
sd::memory::MemoryRegistrator::getInstance().setGraphMemoryFootprintIfGreater(h, m);
|
||||
}
|
||||
|
||||
if (tempFlow) {
|
||||
|
@ -523,7 +523,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
// converting FlatGraph to internal representation
|
||||
auto nativeGraph = new Graph(restoredGraph);
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
if (Environment::getInstance().isDebugAndVerbose()) {
|
||||
nativeGraph->printOut();
|
||||
}
|
||||
|
||||
|
@ -742,7 +742,7 @@ Graph* GraphExecutioner::importFromTensorFlow(const char *fileName) {
|
|||
nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, node.name().c_str(),
|
||||
node.op().c_str());
|
||||
|
||||
sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str());
|
||||
sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance().getOperationFloat(node.op().c_str());
|
||||
|
||||
if (op == nullptr) {
|
||||
nd4j_verbose("Op wasn't found: %s\n", node.op().c_str());
|
||||
|
@ -859,7 +859,7 @@ flatbuffers::Offset<FlatResult> GraphExecutioner::execute(Graph *graph, flatbuff
|
|||
}
|
||||
}
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose())
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
graph->printOut();
|
||||
|
||||
auto status = GraphExecutioner::execute(graph);
|
||||
|
|
|
@ -25,11 +25,9 @@
|
|||
|
||||
namespace sd {
|
||||
namespace graph {
|
||||
GraphHolder* GraphHolder::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new GraphHolder();
|
||||
|
||||
return _INSTANCE;
|
||||
GraphHolder& GraphHolder::getInstance() {
|
||||
static GraphHolder instance;
|
||||
return instance;
|
||||
};
|
||||
|
||||
void GraphHolder::registerGraph(Nd4jLong graphId, Graph* graph) {
|
||||
|
@ -126,7 +124,5 @@ namespace sd {
|
|||
|
||||
return res;
|
||||
}
|
||||
|
||||
GraphHolder* GraphHolder::_INSTANCE = 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -636,7 +636,7 @@ namespace sd {
|
|||
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
|
||||
}
|
||||
} else if (this->_opType == OpType_CUSTOM) {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum());
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(this->opNum());
|
||||
if (op == nullptr) {
|
||||
nd4j_verbose("Can't find operation: %lld\n", this->opNum());
|
||||
throw std::runtime_error("Can't find requested operation");
|
||||
|
|
|
@ -364,8 +364,6 @@ namespace sd {
|
|||
|
||||
class BlasHelper {
|
||||
private:
|
||||
static BlasHelper* _instance;
|
||||
|
||||
bool _hasHgemv = false;
|
||||
bool _hasHgemm = false;
|
||||
bool _hasHgemmBatch = false;
|
||||
|
@ -404,7 +402,7 @@ namespace sd {
|
|||
CusolverDnDgesvd cusolverDnDgesvd;
|
||||
|
||||
public:
|
||||
static BlasHelper* getInstance();
|
||||
static BlasHelper& getInstance();
|
||||
|
||||
void initializeFunctions(Nd4jPointer *functions);
|
||||
void initializeDeviceFunctions(Nd4jPointer *functions);
|
||||
|
|
|
@ -35,7 +35,6 @@
|
|||
namespace sd {
|
||||
class ND4J_EXPORT ConstantHelper {
|
||||
private:
|
||||
static ConstantHelper* _INSTANCE;
|
||||
ConstantHelper();
|
||||
|
||||
std::vector<MAP_IMPL<ConstantDescriptor, ConstantHolder*>> _cache;
|
||||
|
@ -48,9 +47,9 @@ namespace sd {
|
|||
|
||||
std::vector<Nd4jLong> _counters;
|
||||
public:
|
||||
~ConstantHelper() = default;
|
||||
~ConstantHelper();
|
||||
|
||||
static ConstantHelper* getInstance();
|
||||
static ConstantHelper& getInstance();
|
||||
static int getCurrentDevice();
|
||||
static int getNumberOfDevices();
|
||||
void* replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace = nullptr);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <array/ShapeDescriptor.h>
|
||||
#include <array/ConstantDataBuffer.h>
|
||||
#include <array/ConstantShapeBuffer.h>
|
||||
#include <memory/Workspace.h>
|
||||
#include <system/op_boilerplate.h>
|
||||
|
||||
|
@ -35,24 +35,22 @@ namespace sd {
|
|||
|
||||
class ND4J_EXPORT ConstantShapeHelper {
|
||||
private:
|
||||
static ConstantShapeHelper *_INSTANCE;
|
||||
|
||||
std::mutex _mutex;
|
||||
std::vector<MAP_IMPL<ShapeDescriptor, ConstantDataBuffer>> _cache;
|
||||
std::vector<MAP_IMPL<ShapeDescriptor, ConstantShapeBuffer>> _cache;
|
||||
|
||||
|
||||
ConstantShapeHelper();
|
||||
public:
|
||||
~ConstantShapeHelper() = default;
|
||||
|
||||
static ConstantShapeHelper* getInstance();
|
||||
static ConstantShapeHelper & getInstance();
|
||||
|
||||
|
||||
ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
||||
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
||||
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||
ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
|
||||
ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
|
||||
ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
|
||||
ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
|
||||
ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||
ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
|
||||
ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
|
||||
|
||||
|
||||
const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
|
||||
|
|
|
@ -35,8 +35,6 @@
|
|||
namespace sd {
|
||||
class ND4J_EXPORT ConstantTadHelper {
|
||||
private:
|
||||
static ConstantTadHelper *_INSTANCE;
|
||||
|
||||
std::mutex _mutex;
|
||||
std::vector<MAP_IMPL<TadDescriptor, TadPack>> _cache;
|
||||
|
||||
|
@ -44,7 +42,7 @@ namespace sd {
|
|||
public:
|
||||
~ConstantTadHelper() = default;
|
||||
|
||||
static ConstantTadHelper* getInstance();
|
||||
static ConstantTadHelper & getInstance();
|
||||
|
||||
/**
|
||||
* These methods calculate Tensor-Along-Dimension(s) shape and offsets
|
||||
|
|
|
@ -44,7 +44,7 @@ namespace sd {
|
|||
// cuda-specific debug functions
|
||||
#ifdef __CUDACC__
|
||||
static FORCEINLINE void checkErrorCode(cudaStream_t *stream, int opNum = 0) {
|
||||
if (Environment::getInstance()->isDebug()) {
|
||||
if (Environment::getInstance().isDebug()) {
|
||||
cudaError_t res = cudaStreamSynchronize(*stream);
|
||||
|
||||
if (res != 0) {
|
||||
|
|
|
@ -206,7 +206,7 @@ LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const
|
|||
const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c';
|
||||
const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';;
|
||||
|
||||
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
||||
if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance().elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 &&
|
||||
tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||
return SMALLARR2DX;
|
||||
if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC)))
|
||||
|
|
|
@ -702,21 +702,21 @@ namespace sd {
|
|||
std::vector<Nd4jLong> zeroOffsets;
|
||||
|
||||
if (xLen == yLen) {
|
||||
tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dims, dimsLen);
|
||||
tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dims, dimsLen);
|
||||
tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dims, dimsLen);
|
||||
tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dims, dimsLen);
|
||||
xTadShapeInfo = tadPackX.primaryShapeInfo();
|
||||
yTadShapeInfo = tadPackY.primaryShapeInfo();
|
||||
xTadOffsets = tadPackX.primaryOffsets();
|
||||
yTadOffsets = tadPackY.primaryOffsets();
|
||||
}
|
||||
else if (yLen > xLen) {
|
||||
tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dims, dimsLen);
|
||||
tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dims, dimsLen);
|
||||
xTadShapeInfo = xShapeInfo;
|
||||
yTadShapeInfo = tadPackY.primaryShapeInfo();
|
||||
yTadOffsets = tadPackY.primaryOffsets();
|
||||
}
|
||||
else {
|
||||
tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dims, dimsLen);
|
||||
tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dims, dimsLen);
|
||||
yTadShapeInfo = yShapeInfo;
|
||||
xTadShapeInfo = tadPackX.primaryShapeInfo();
|
||||
xTadOffsets = tadPackX.primaryOffsets();
|
||||
|
|
|
@ -32,8 +32,6 @@
|
|||
namespace sd {
|
||||
class ND4J_EXPORT OpTracker {
|
||||
private:
|
||||
static OpTracker* _INSTANCE;
|
||||
|
||||
std::string _export;
|
||||
|
||||
int _operations = 0;
|
||||
|
@ -45,7 +43,7 @@ namespace sd {
|
|||
template <typename T>
|
||||
std::string local_to_string(T value);
|
||||
public:
|
||||
static OpTracker* getInstance();
|
||||
static OpTracker& getInstance();
|
||||
|
||||
int totalGroups();
|
||||
int totalOperations();
|
||||
|
|
|
@ -69,14 +69,14 @@ namespace sd {
|
|||
void executeOnce() override {
|
||||
PointersManager manager(LaunchContext::defaultContext(), "BroadcastBM");
|
||||
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(_x->shapeInfo(), _axis);
|
||||
auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(_z->shapeInfo(), _axis);
|
||||
auto packX = ConstantTadHelper::getInstance().tadForDimensions(_x->shapeInfo(), _axis);
|
||||
auto packZ = ConstantTadHelper::getInstance().tadForDimensions(_z->shapeInfo(), _axis);
|
||||
|
||||
auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo();
|
||||
auto tadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets();
|
||||
auto tadOnlyShapeInfo = Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo();
|
||||
auto tadOffsets = Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX.specialOffsets();
|
||||
|
||||
auto tadOnlyShapeInfoZ = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo();
|
||||
auto tadOffsetsZ = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets();
|
||||
auto tadOnlyShapeInfoZ = Environment::getInstance().isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo();
|
||||
auto tadOffsetsZ = Environment::getInstance().isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets();
|
||||
|
||||
NativeOpExecutioner::execBroadcast(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(),
|
||||
/*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ);
|
||||
|
|
|
@ -36,7 +36,7 @@ namespace sd {
|
|||
sd::graph::Context *_context = nullptr;
|
||||
public:
|
||||
DeclarableBenchmark(sd::ops::DeclarableOp &op, std::string name = 0) : OpBenchmark() {
|
||||
_op = &op; //ops::OpRegistrator::getInstance()->getOperation(op.getOpHash());
|
||||
_op = &op; //ops::OpRegistrator::getInstance().getOperation(op.getOpHash());
|
||||
_testName = name;
|
||||
}
|
||||
|
||||
|
|
|
@ -88,10 +88,10 @@ namespace sd {
|
|||
else
|
||||
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo());
|
||||
else {
|
||||
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_x->shapeInfo(), _axis);
|
||||
auto pack = ConstantTadHelper::getInstance().tadForDimensions(_x->shapeInfo(), _axis);
|
||||
|
||||
auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? pack.primaryShapeInfo() : pack.specialShapeInfo();
|
||||
auto tadOffsets = Environment::getInstance()->isCPU() ? pack.primaryOffsets() : pack.specialOffsets();
|
||||
auto tadOnlyShapeInfo = Environment::getInstance().isCPU() ? pack.primaryShapeInfo() : pack.specialShapeInfo();
|
||||
auto tadOffsets = Environment::getInstance().isCPU() ? pack.primaryOffsets() : pack.specialOffsets();
|
||||
|
||||
if (_opType == 0)
|
||||
NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets);
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <loops/type_conversions.h>
|
||||
#include <system/type_boilerplate.h>
|
||||
#include <cstring>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
|
@ -42,11 +43,17 @@ namespace sd {
|
|||
}
|
||||
}
|
||||
|
||||
ConstantHelper* ConstantHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new sd::ConstantHelper();
|
||||
ConstantHelper::~ConstantHelper() {
|
||||
for (const auto &v:_cache) {
|
||||
for (const auto &c:v) {
|
||||
delete c.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantHelper& ConstantHelper::getInstance() {
|
||||
static ConstantHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
|
||||
|
@ -95,17 +102,17 @@ namespace sd {
|
|||
result = holder->getConstantDataBuffer(dataType);
|
||||
else {
|
||||
auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||
auto cbuff = new int8_t[size];
|
||||
auto cbuff = std::make_shared<PointerWrapper>(new int8_t[size], std::make_shared<PrimaryPointerDeallocator>());
|
||||
_counters[deviceId] += size;
|
||||
|
||||
// create buffer with this dtype
|
||||
if (descriptor.isFloat()) {
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast<double *>(descriptor.floatValues().data()), descriptor.length(), cbuff), (sd::DataType::DOUBLE, double), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast<double *>(descriptor.floatValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::DOUBLE, double), LIBND4J_TYPES);
|
||||
} else if (descriptor.isInteger()) {
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast<Nd4jLong *>(descriptor.integerValues().data()), descriptor.length(), cbuff), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast<Nd4jLong *>(descriptor.integerValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||
ConstantDataBuffer dataBuffer(cbuff, descriptor.length(), dataType);
|
||||
holder->addBuffer(dataBuffer, dataType);
|
||||
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
|
@ -122,8 +129,6 @@ namespace sd {
|
|||
else
|
||||
return _counters[deviceId];
|
||||
}
|
||||
|
||||
sd::ConstantHelper* sd::ConstantHelper::_INSTANCE = 0;
|
||||
}
|
||||
|
||||
#endif
|
|
@ -24,51 +24,50 @@
|
|||
#include <helpers/logger.h>
|
||||
#include <helpers/ShapeBuilders.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
ConstantShapeHelper::ConstantShapeHelper() {
|
||||
_cache.resize(32);
|
||||
for (int e = 0; e < 32; e++) {
|
||||
MAP_IMPL<ShapeDescriptor, ConstantDataBuffer> cache;
|
||||
MAP_IMPL<ShapeDescriptor, ConstantShapeBuffer> cache;
|
||||
_cache[e] = cache;
|
||||
}
|
||||
}
|
||||
|
||||
ConstantShapeHelper* ConstantShapeHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new ConstantShapeHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantShapeHelper& ConstantShapeHelper::getInstance() {
|
||||
static ConstantShapeHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
int deviceId = 0;
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
int deviceId = 0;
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
auto hPtr = descriptor.toShapeInfo();
|
||||
ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
return _cache[deviceId][descriptor1];
|
||||
} else {
|
||||
return _cache[deviceId].at(descriptor);
|
||||
}
|
||||
}
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
auto hPtr = std::make_shared<PointerWrapper>(descriptor.toShapeInfo(), std::make_shared<PrimaryPointerDeallocator>());
|
||||
ConstantShapeBuffer buffer(hPtr);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
return _cache[deviceId][descriptor1];
|
||||
} else {
|
||||
return _cache[deviceId].at(descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ShapeDescriptor descriptor(shapeInfo);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
@ -83,7 +82,7 @@ namespace sd {
|
|||
|
||||
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
|
||||
|
@ -92,26 +91,26 @@ namespace sd {
|
|||
|
||||
const Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
|
||||
|
@ -135,7 +134,7 @@ namespace sd {
|
|||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> &dimensions) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> &dimensions) {
|
||||
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);
|
||||
|
@ -185,10 +184,6 @@ ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(c
|
|||
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
|
||||
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
|
||||
|
||||
}
|
||||
} // namespace sd
|
||||
|
||||
#endif
|
|
@ -21,6 +21,8 @@
|
|||
#include "../ConstantTadHelper.h"
|
||||
#include <helpers/TAD.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <array/ConstantOffsetsBuffer.h>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
|
||||
#ifndef __CUDABLAS__
|
||||
|
||||
|
@ -32,11 +34,9 @@ namespace sd {
|
|||
_cache.emplace_back(pack);
|
||||
}
|
||||
|
||||
ConstantTadHelper* ConstantTadHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new ConstantTadHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantTadHelper& ConstantTadHelper::getInstance() {
|
||||
static ConstantTadHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
|
@ -60,60 +60,31 @@ namespace sd {
|
|||
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
|
||||
const int deviceId = 0;
|
||||
|
||||
_mutex.lock();
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
|
||||
// if there's no TadPack matching this descriptor - create one
|
||||
const auto shapeInfo = descriptor.originalShape().toShapeInfo();
|
||||
const int rank = shape::rank(shapeInfo);
|
||||
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis());
|
||||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
||||
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
||||
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them)
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
auto sPtr = std::make_shared<PointerWrapper>(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared<PrimaryPointerDeallocator>()); // shape of sub-arrays (same for all for them)
|
||||
auto oPtr = std::make_shared<PointerWrapper>(new Nd4jLong[numOfSubArrs], std::make_shared<PrimaryPointerDeallocator>());
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||
|
||||
|
||||
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
|
||||
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);
|
||||
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
||||
|
||||
|
||||
|
||||
// auto shapeInfo = descriptor.originalShape().toShapeInfo();
|
||||
// shape::TAD tad;
|
||||
// tad.init(shapeInfo, descriptor.axis().data(), descriptor.axis().size());
|
||||
// tad.createTadOnlyShapeInfo();
|
||||
// tad.createOffsets();
|
||||
|
||||
// auto sPtr = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo)];
|
||||
// auto oPtr = new Nd4jLong[tad.numTads];
|
||||
|
||||
// memcpy(sPtr, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
|
||||
// memcpy(oPtr, tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
|
||||
|
||||
// TadPack t(shapesBuffer, offsetsBuffer, tad.numTads);
|
||||
|
||||
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT<Nd4jLong>(), oPtr->pointerAsT<Nd4jLong>(), descriptor.areUnitiesinShape());
|
||||
|
||||
ConstantShapeBuffer shapeBuffer(sPtr);
|
||||
ConstantOffsetsBuffer offsetsBuffer(oPtr);
|
||||
TadPack t(shapeBuffer, offsetsBuffer, numOfSubArrs);
|
||||
_cache[deviceId][descriptor] = t;
|
||||
|
||||
TadPack &r = _cache[deviceId][descriptor];
|
||||
_mutex.unlock();
|
||||
|
||||
delete[] shapeInfo;
|
||||
|
||||
return r;
|
||||
} else {
|
||||
TadPack r = _cache[deviceId][descriptor];
|
||||
_mutex.unlock();
|
||||
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
sd::ConstantTadHelper* sd::ConstantTadHelper::_INSTANCE = 0;
|
||||
return _cache[deviceId][descriptor];
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -162,7 +162,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
|
|||
const bool betaPersent = beta;
|
||||
|
||||
T3 sum = 0;
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance().elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||
for(Nd4jLong i = 0; i < length; ++i)
|
||||
sum += X[i * incx] * Y[i * incy];
|
||||
|
||||
|
@ -210,7 +210,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
|||
const auto cType = C->dataType();
|
||||
|
||||
const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC);
|
||||
const bool hasGemm = BlasHelper::getInstance()->hasGEMM(aType);
|
||||
const bool hasGemm = BlasHelper::getInstance().hasGEMM(aType);
|
||||
|
||||
const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE;
|
||||
const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32;
|
||||
|
@ -261,10 +261,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
|||
const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1);
|
||||
|
||||
if(typeFloat) {
|
||||
BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT<float>(), lda, pB->bufferAsT<float>(), ldb, (float) beta, pC->bufferAsT<float>(), ldc);
|
||||
BlasHelper::getInstance().sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT<float>(), lda, pB->bufferAsT<float>(), ldb, (float) beta, pC->bufferAsT<float>(), ldc);
|
||||
}
|
||||
else if(typeDouble) {
|
||||
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT<double>(), lda, pB->bufferAsT<double>(), ldb, (double) beta, pC->bufferAsT<double>(), ldc);
|
||||
BlasHelper::getInstance().dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT<double>(), lda, pB->bufferAsT<double>(), ldb, (double) beta, pC->bufferAsT<double>(), ldc);
|
||||
}
|
||||
|
||||
if(pC != C) {
|
||||
|
@ -321,7 +321,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
|
|||
const auto yType = Y->dataType();
|
||||
|
||||
const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY);
|
||||
const bool hasGemv = BlasHelper::getInstance()->hasGEMV(aType);
|
||||
const bool hasGemv = BlasHelper::getInstance().hasGEMV(aType);
|
||||
|
||||
const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE;
|
||||
const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32;
|
||||
|
@ -347,10 +347,10 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
|
|||
|
||||
// choose appropriate cuda gemm api depending on data types
|
||||
if(typeDouble) {
|
||||
BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy);
|
||||
BlasHelper::getInstance().dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy);
|
||||
}
|
||||
else if(typeFloat) {
|
||||
BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy);
|
||||
BlasHelper::getInstance().sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy);
|
||||
}
|
||||
|
||||
if(pA != A)
|
||||
|
@ -617,7 +617,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
|||
const bool flagA = (flagC && transA) || (!flagC && !transA);
|
||||
const bool flagB = (flagC && transB) || (!flagC && !transB);
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance().elementwiseThreshold()) schedule(guided))
|
||||
// for(uint row = 0; row < M; ++row) {
|
||||
|
||||
// T3* c = flagC ? (C + row) : (C + row * ldc);
|
||||
|
|
|
@ -37,11 +37,9 @@ namespace sd {
|
|||
|
||||
}
|
||||
|
||||
CublasHelper* CublasHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new sd::CublasHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
CublasHelper& CublasHelper::getInstance() {
|
||||
static CublasHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void* CublasHelper::handle() {
|
||||
|
@ -55,7 +53,4 @@ namespace sd {
|
|||
void* CublasHelper::handle(int deviceId) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0;
|
||||
}
|
|
@ -29,7 +29,6 @@
|
|||
namespace sd {
|
||||
class ND4J_EXPORT CublasHelper {
|
||||
private:
|
||||
static CublasHelper *_INSTANCE;
|
||||
static std::mutex _mutex;
|
||||
|
||||
std::vector<void*> _cache;
|
||||
|
@ -37,9 +36,9 @@ namespace sd {
|
|||
std::vector<void*> _cudnn;
|
||||
|
||||
CublasHelper();
|
||||
~CublasHelper();
|
||||
public:
|
||||
static CublasHelper* getInstance();
|
||||
~CublasHelper();
|
||||
static CublasHelper& getInstance();
|
||||
|
||||
void* cudnn();
|
||||
void* solver();
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <execution/AffinityManager.h>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
|
||||
#define CONSTANT_LIMIT 49152
|
||||
|
||||
|
@ -84,11 +85,17 @@ namespace sd {
|
|||
throw cuda_exception::build("Final cudaSetDevice failed", res);
|
||||
}
|
||||
|
||||
ConstantHelper* ConstantHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new sd::ConstantHelper();
|
||||
ConstantHelper::~ConstantHelper() {
|
||||
for (const auto &v:_cache) {
|
||||
for (const auto &c:v) {
|
||||
delete c.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantHelper& ConstantHelper::getInstance() {
|
||||
static ConstantHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
|
||||
|
@ -156,19 +163,21 @@ namespace sd {
|
|||
result = holder->getConstantDataBuffer(dataType);
|
||||
} else {
|
||||
auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
|
||||
auto cbuff = new int8_t[numBytes];
|
||||
auto cbuff = std::make_shared<PointerWrapper>(new int8_t[numBytes], std::make_shared<PrimaryPointerDeallocator>());
|
||||
_counters[deviceId] += numBytes;
|
||||
|
||||
// create buffer with this dtype
|
||||
if (descriptor.isFloat()) {
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<double *>(descriptor.floatValues().data()), descriptor.length(), cbuff), (sd::DataType::DOUBLE, double), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<double *>(descriptor.floatValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::DOUBLE, double), LIBND4J_TYPES);
|
||||
} else if (descriptor.isInteger()) {
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<Nd4jLong *>(descriptor.integerValues().data()), descriptor.length(), cbuff), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<Nd4jLong *>(descriptor.integerValues().data()), descriptor.length(), cbuff->pointer()), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));
|
||||
// we don't have deallocator here.
|
||||
// TODO: we probably want to make use deallocator here, if we're not using constant memory
|
||||
auto dbuff = std::make_shared<PointerWrapper>(replicatePointer(cbuff->pointer(), descriptor.length() * DataTypeUtils::sizeOf(dataType)));
|
||||
|
||||
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
|
||||
ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), dataType);
|
||||
|
||||
holder->addBuffer(dataBuffer, dataType);
|
||||
result = holder->getConstantDataBuffer(dataType);
|
||||
|
@ -184,6 +193,4 @@ namespace sd {
|
|||
else
|
||||
return _counters[deviceId];
|
||||
}
|
||||
|
||||
sd::ConstantHelper* sd::ConstantHelper::_INSTANCE = 0;
|
||||
}
|
|
@ -24,6 +24,8 @@
|
|||
#include <helpers/ShapeBuilders.h>
|
||||
#include <execution/AffinityManager.h>
|
||||
#include <helpers/ConstantHelper.h>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
#include <array/CudaPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
|
||||
|
@ -32,46 +34,44 @@ namespace sd {
|
|||
|
||||
_cache.resize(numDevices);
|
||||
for (int e = 0; e < numDevices; e++) {
|
||||
MAP_IMPL<ShapeDescriptor, ConstantDataBuffer> cache;
|
||||
MAP_IMPL<ShapeDescriptor, ConstantShapeBuffer> cache;
|
||||
_cache[e] = cache;
|
||||
}
|
||||
}
|
||||
|
||||
ConstantShapeHelper* ConstantShapeHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new ConstantShapeHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantShapeHelper& ConstantShapeHelper::getInstance() {
|
||||
static ConstantShapeHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
int deviceId = AffinityManager::currentDeviceId();
|
||||
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
|
||||
if (_cache[deviceId].count(descriptor) == 0) {
|
||||
auto hPtr = descriptor.toShapeInfo();
|
||||
auto dPtr = ConstantHelper::getInstance()->replicatePointer(hPtr, shape::shapeInfoByteLength(hPtr));
|
||||
ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
return _cache[deviceId][descriptor1];
|
||||
auto hPtr = std::make_shared<PointerWrapper>(descriptor.toShapeInfo(), std::make_shared<PrimaryPointerDeallocator>());
|
||||
auto dPtr = std::make_shared<PointerWrapper>(ConstantHelper::getInstance().replicatePointer(hPtr->pointer(), shape::shapeInfoByteLength(hPtr->pointerAsT<Nd4jLong>())), std::make_shared<CudaPointerDeallocator>());
|
||||
ConstantShapeBuffer buffer(hPtr, dPtr);
|
||||
ShapeDescriptor descriptor1(descriptor);
|
||||
_cache[deviceId][descriptor1] = buffer;
|
||||
return _cache[deviceId][descriptor1];
|
||||
} else {
|
||||
return _cache[deviceId].at(descriptor);
|
||||
return _cache[deviceId].at(descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
ShapeDescriptor descriptor(shapeInfo);
|
||||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ namespace sd {
|
|||
|
||||
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape, rank);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) {
|
||||
|
@ -94,26 +94,26 @@ namespace sd {
|
|||
|
||||
Nd4jLong const* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::scalarDescriptor(dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) {
|
||||
auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector<Nd4jLong> &shape) {
|
||||
ShapeDescriptor descriptor(dataType, order, shape);
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) {
|
||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) {
|
||||
|
@ -136,7 +136,7 @@ namespace sd {
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int>& dimensions) {
|
||||
ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int>& dimensions) {
|
||||
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);
|
||||
|
@ -187,7 +187,4 @@ ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(c
|
|||
return bufferForShapeInfo(descriptor);
|
||||
}
|
||||
|
||||
|
||||
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
|
||||
|
||||
}
|
|
@ -25,6 +25,8 @@
|
|||
#include <exceptions/cuda_exception.h>
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <array/PrimaryPointerDeallocator.h>
|
||||
#include <array/CudaPointerDeallocator.h>
|
||||
|
||||
namespace sd {
|
||||
ConstantTadHelper::ConstantTadHelper() {
|
||||
|
@ -36,11 +38,9 @@ namespace sd {
|
|||
}
|
||||
}
|
||||
|
||||
ConstantTadHelper* ConstantTadHelper::getInstance() {
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new ConstantTadHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
ConstantTadHelper& ConstantTadHelper::getInstance() {
|
||||
static ConstantTadHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
|
@ -73,25 +73,28 @@ namespace sd {
|
|||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
||||
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
||||
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
auto sPtr = std::make_shared<PointerWrapper>(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared<PrimaryPointerDeallocator>());
|
||||
auto oPtr = std::make_shared<PointerWrapper>(new Nd4jLong[numOfSubArrs], std::make_shared<PrimaryPointerDeallocator>());
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT<Nd4jLong>(), oPtr->pointerAsT<Nd4jLong>(), descriptor.areUnitiesinShape());
|
||||
|
||||
Nd4jPointer soPtr;
|
||||
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("Memory allocation for tadOffsets failed", res);
|
||||
|
||||
res = cudaMemcpy(soPtr, oPtr, numOfSubArrs * sizeof(Nd4jLong), cudaMemcpyHostToDevice);
|
||||
res = cudaMemcpy(soPtr, oPtr->pointer(), numOfSubArrs * sizeof(Nd4jLong), cudaMemcpyHostToDevice);
|
||||
if (res != 0)
|
||||
throw cuda_exception::build("tadOffsets copy failed", res);
|
||||
|
||||
auto ssPtr = ConstantHelper::getInstance()->replicatePointer(sPtr, shape::shapeInfoByteLength(subArrRank));
|
||||
// TODO: add deallocator here?
|
||||
auto ssPtr = std::make_shared<PointerWrapper>(ConstantHelper::getInstance().replicatePointer(sPtr->pointer(), shape::shapeInfoByteLength(subArrRank)));
|
||||
|
||||
ConstantDataBuffer shapesBuffer(sPtr, ssPtr, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong), DataType::INT64);
|
||||
ConstantDataBuffer offsetsBuffer(oPtr, soPtr, numOfSubArrs * sizeof(Nd4jLong), DataType::INT64);
|
||||
|
||||
|
||||
ConstantShapeBuffer shapesBuffer(sPtr, ssPtr);
|
||||
ConstantOffsetsBuffer offsetsBuffer(oPtr, std::make_shared<PointerWrapper>(soPtr, std::make_shared<CudaPointerDeallocator>()));
|
||||
|
||||
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
||||
_cache[deviceId][descriptor] = t;
|
||||
|
@ -107,6 +110,4 @@ namespace sd {
|
|||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
sd::ConstantTadHelper* sd::ConstantTadHelper::_INSTANCE = 0;
|
||||
}
|
|
@ -238,7 +238,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
|||
if (C->isEmpty())
|
||||
return C;
|
||||
|
||||
const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first();
|
||||
const int major = Environment::getInstance().capabilities()[AffinityManager::currentDeviceId()].first();
|
||||
|
||||
const auto aType = A->dataType();
|
||||
const auto bType = B->dataType();
|
||||
|
@ -268,7 +268,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
|||
const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank
|
||||
|
||||
NDArray::prepareSpecialUse({C}, {A, B});
|
||||
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->special(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
|
||||
NDArray::registerSpecialUse({C}, {A, B});
|
||||
|
||||
|
@ -411,7 +411,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
|
|||
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
NDArray::prepareSpecialUse({Y}, {A, X});
|
||||
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->special(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
|
||||
NDArray::registerSpecialUse({Y}, {A, X});
|
||||
|
||||
|
@ -667,7 +667,7 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
|
|||
cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int)));
|
||||
|
||||
NDArray::prepareSpecialUse({C}, {A, B});
|
||||
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->special(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES)
|
||||
NDArray::registerSpecialUse({C}, {A, B});
|
||||
|
||||
|
|
|
@ -102,13 +102,9 @@ namespace sd {
|
|||
destroyHandle_(_cache[e]);
|
||||
}
|
||||
|
||||
CublasHelper* CublasHelper::getInstance() {
|
||||
_mutex.lock();
|
||||
if (!_INSTANCE)
|
||||
_INSTANCE = new sd::CublasHelper();
|
||||
_mutex.unlock();
|
||||
|
||||
return _INSTANCE;
|
||||
CublasHelper& CublasHelper::getInstance() {
|
||||
static CublasHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void* CublasHelper::cudnn() {
|
||||
|
@ -138,7 +134,4 @@ namespace sd {
|
|||
|
||||
return _cache[deviceId];
|
||||
}
|
||||
|
||||
|
||||
sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0;
|
||||
}
|
|
@ -31,8 +31,6 @@ namespace sd {
|
|||
namespace ops {
|
||||
class ND4J_EXPORT HashHelper {
|
||||
private:
|
||||
static HashHelper* _INSTANCE;
|
||||
|
||||
Nd4jLong _byteTable[256];
|
||||
const Nd4jLong HSTART = 0xBB40E64DA205B064L;
|
||||
const Nd4jLong HMULT = 7664345821815920749L;
|
||||
|
@ -41,7 +39,7 @@ namespace sd {
|
|||
std::mutex _locker;
|
||||
|
||||
public:
|
||||
static HashHelper* getInstance();
|
||||
static HashHelper& getInstance();
|
||||
Nd4jLong getLongHash(std::string& str);
|
||||
};
|
||||
}
|
||||
|
|
|
@ -20,10 +20,9 @@
|
|||
|
||||
#include <helpers/BlasHelper.h>
|
||||
namespace sd {
|
||||
BlasHelper* BlasHelper::getInstance() {
|
||||
if (_instance == 0)
|
||||
_instance = new BlasHelper();
|
||||
return _instance;
|
||||
BlasHelper& BlasHelper::getInstance() {
|
||||
static BlasHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
|
||||
|
@ -74,7 +73,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasGEMV<float>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -86,7 +85,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasGEMV<double>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -138,7 +137,7 @@ namespace sd {
|
|||
|
||||
bool BlasHelper::hasGEMV(const sd::DataType dtype) {
|
||||
if(dtype == DataType::FLOAT32) {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -148,7 +147,7 @@ namespace sd {
|
|||
#endif
|
||||
}
|
||||
if(dtype == DataType::DOUBLE) {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -162,7 +161,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasGEMM<float>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -174,7 +173,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasGEMM<double>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -226,7 +225,7 @@ namespace sd {
|
|||
|
||||
bool BlasHelper:: hasGEMM(const sd::DataType dtype) {
|
||||
if(dtype == DataType::FLOAT32) {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -236,7 +235,7 @@ namespace sd {
|
|||
#endif
|
||||
}
|
||||
if(dtype == DataType::DOUBLE) {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||
|
@ -251,7 +250,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasBatchedGEMM<float>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
return _hasSgemmBatch;
|
||||
|
@ -259,7 +258,7 @@ namespace sd {
|
|||
|
||||
template <>
|
||||
bool BlasHelper::hasBatchedGEMM<double>() {
|
||||
if (sd::Environment::getInstance()->blasFallback())
|
||||
if (sd::Environment::getInstance().blasFallback())
|
||||
return false;
|
||||
|
||||
return _hasDgemmBatch;
|
||||
|
@ -362,6 +361,4 @@ namespace sd {
|
|||
|
||||
// destructor
|
||||
BlasHelper::~BlasHelper() noexcept { }
|
||||
|
||||
BlasHelper* BlasHelper::_instance = 0;
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace sd {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) {
|
||||
|
||||
auto maxItersPerThread = Environment::getInstance()->elementwiseThreshold();
|
||||
auto maxItersPerThread = Environment::getInstance().elementwiseThreshold();
|
||||
|
||||
if(N < maxItersPerThread)
|
||||
_numThreads = 1;
|
||||
|
@ -45,7 +45,7 @@ OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) {
|
|||
else
|
||||
desiredNumThreads = sd::math::nd4j_min<int>(omp_get_max_threads(), desiredNumThreads);
|
||||
#else
|
||||
desiredNumThreads = sd::Environment::getInstance()->maxThreads();
|
||||
desiredNumThreads = sd::Environment::getInstance().maxThreads();
|
||||
#endif
|
||||
_numThreads = sd::math::nd4j_min<int>(N / maxItersPerThread, desiredNumThreads);
|
||||
}
|
||||
|
@ -75,12 +75,12 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
|
|||
#ifdef _OPENMP
|
||||
return betterThreads(N, omp_get_max_threads());
|
||||
#else
|
||||
return betterThreads(N, sd::Environment::getInstance()->maxThreads());;
|
||||
return betterThreads(N, sd::Environment::getInstance().maxThreads());;
|
||||
#endif
|
||||
}
|
||||
|
||||
int OmpLaunchHelper::betterThreads(Nd4jLong N, int maxThreads) {
|
||||
auto t = Environment::getInstance()->elementwiseThreshold();
|
||||
auto t = Environment::getInstance().elementwiseThreshold();
|
||||
if (N < t)
|
||||
return 1;
|
||||
else {
|
||||
|
@ -92,7 +92,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
|
|||
#ifdef _OPENMP
|
||||
auto maxThreads = omp_get_max_threads();
|
||||
#else
|
||||
auto maxThreads = sd::Environment::getInstance()->maxThreads();
|
||||
auto maxThreads = sd::Environment::getInstance().maxThreads();
|
||||
#endif
|
||||
|
||||
// if there's only 1 thread allowed - nothing to do here
|
||||
|
@ -102,7 +102,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
|
|||
auto totalLength = tadLength * numTads;
|
||||
|
||||
// if array is tiny - no need to spawn any threeds
|
||||
if (totalLength < Environment::getInstance()->elementwiseThreshold())
|
||||
if (totalLength < Environment::getInstance().elementwiseThreshold())
|
||||
return 1;
|
||||
|
||||
// by default we're spawning as many threads we can, but not more than number of TADs
|
||||
|
|
|
@ -29,11 +29,9 @@ using namespace sd::graph;
|
|||
|
||||
namespace sd {
|
||||
|
||||
OpTracker* OpTracker::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new OpTracker();
|
||||
|
||||
return _INSTANCE;
|
||||
OpTracker& OpTracker::getInstance() {
|
||||
static OpTracker instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void OpTracker::storeOperation(sd::graph::OpType opType, const OpDescriptor& descriptor) {
|
||||
|
@ -118,6 +116,4 @@ namespace sd {
|
|||
|
||||
return _export.c_str();
|
||||
}
|
||||
|
||||
sd::OpTracker* sd::OpTracker::_INSTANCE = 0;
|
||||
}
|
||||
|
|
|
@ -130,7 +130,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
|
||||
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||
RELEASE(outShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const int rank = shape::rank(shapeInfo);
|
||||
|
@ -168,7 +168,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
|
||||
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||
RELEASE(outShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
|
||||
|
@ -207,20 +207,20 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
else if(supportOldShapes) {
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
|
||||
shape::shapeOldScalar(dataType, newShapeInfo, 'c');
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
else {
|
||||
newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,7 +241,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
int newRank = rank - dimSize;
|
||||
|
@ -252,13 +252,13 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, 'c');
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
else {
|
||||
newShapeInfo = ShapeBuilders::createScalarShapeInfo(ArrayOptions::dataType(shapeInfo), workspace);
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -289,7 +289,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
|
||||
ShapeDescriptor descriptor(newShapeInfo, dataType);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -341,7 +341,7 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
|
||||
RELEASE(shapeInfoNew, workspace);
|
||||
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -486,7 +486,7 @@ bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLo
|
|||
|
||||
ShapeDescriptor descriptor(tmpShapeInfo);
|
||||
RELEASE(tmpShapeInfo, workspace);
|
||||
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
resultShapeInfo = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -525,7 +525,7 @@ bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLo
|
|||
|
||||
ShapeDescriptor descriptor(tmpShapeInfo);
|
||||
RELEASE(tmpShapeInfo, workspace);
|
||||
resultShapeInfo = const_cast<Nd4jLong*>(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
|
||||
resultShapeInfo = const_cast<Nd4jLong*>(ConstantShapeHelper::getInstance().createShapeInfo(descriptor));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -594,7 +594,7 @@ bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLo
|
|||
|
||||
ShapeDescriptor descriptor(newShapeInfo);
|
||||
RELEASE(newShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) {
|
||||
|
@ -745,7 +745,7 @@ std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {
|
|||
|
||||
ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));
|
||||
|
||||
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
|
||||
auto result = ConstantShapeHelper::getInstance().createShapeInfo(outputShapeInfo);
|
||||
RELEASE(outputShapeInfo, workspace);
|
||||
return result;
|
||||
}
|
||||
|
@ -832,7 +832,7 @@ std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandSh
|
|||
shape[1] = 1;
|
||||
}
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'f', 2, shape);
|
||||
auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'f', 2, shape);
|
||||
|
||||
RELEASE(shape, workspace);
|
||||
|
||||
|
|
|
@ -24,11 +24,9 @@
|
|||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
HashHelper* HashHelper::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new HashHelper();
|
||||
|
||||
return _INSTANCE;
|
||||
HashHelper& HashHelper::getInstance() {
|
||||
static HashHelper instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
Nd4jLong HashHelper::getLongHash(std::string& str) {
|
||||
|
@ -64,8 +62,6 @@ namespace sd {
|
|||
|
||||
return h;
|
||||
}
|
||||
|
||||
sd::ops::HashHelper* sd::ops::HashHelper::_INSTANCE = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,9 +32,9 @@
|
|||
|
||||
#ifndef __CUDA_ARCH__
|
||||
|
||||
#define nd4j_debug(FORMAT, ...) if (sd::Environment::getInstance()->isDebug() && sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_logger(FORMAT, ...) if (sd::Environment::getInstance()->isDebug() && sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_verbose(FORMAT, ...) if (sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_debug(FORMAT, ...) if (sd::Environment::getInstance().isDebug() && sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_logger(FORMAT, ...) if (sd::Environment::getInstance().isDebug() && sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_verbose(FORMAT, ...) if (sd::Environment::getInstance().isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_printf(FORMAT, ...) sd::Logger::info(FORMAT, __VA_ARGS__);
|
||||
#define nd4j_printv(FORMAT, VECTOR) sd::Logger::printv(FORMAT, VECTOR);
|
||||
|
||||
|
|
|
@ -384,9 +384,9 @@ namespace shape {
|
|||
* @param rank the rank of the shape
|
||||
*/
|
||||
|
||||
ND4J_EXPORT _CUDA_HD int isMatrix(Nd4jLong *shape, int rank);
|
||||
ND4J_EXPORT _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank);
|
||||
|
||||
INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo);
|
||||
INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo);
|
||||
/**
|
||||
* Returns the shape portion of an information
|
||||
* buffer
|
||||
|
@ -2346,7 +2346,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
|
|||
* @param shape the shape of the array
|
||||
* @param rank the rank of the shape
|
||||
*/
|
||||
INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shape, int rank) {
|
||||
INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shape, int rank) {
|
||||
if (rank > 2)
|
||||
return 0;
|
||||
else if (rank <= 2) {
|
||||
|
@ -2357,7 +2357,7 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape)
|
|||
return 1;
|
||||
}
|
||||
|
||||
INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo) {
|
||||
INLINEDEF _CUDA_HD int isMatrix(const Nd4jLong *shapeInfo) {
|
||||
return isMatrix(shape::shapeOf(shapeInfo),shape::rank(shapeInfo));
|
||||
}
|
||||
|
||||
|
|
|
@ -1567,8 +1567,9 @@ ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd
|
|||
|
||||
|
||||
typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer;
|
||||
typedef sd::ConstantShapeBuffer OpaqueConstantShapeBuffer;
|
||||
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty);
|
||||
ND4J_EXPORT OpaqueConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty);
|
||||
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length);
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length);
|
||||
|
@ -1577,9 +1578,12 @@ ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::Con
|
|||
ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf);
|
||||
|
||||
ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr);
|
||||
ND4J_EXPORT Nd4jPointer getConstantShapeBufferPrimary(OpaqueConstantShapeBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jPointer getConstantShapeBufferSpecial(OpaqueConstantShapeBuffer* dbf);
|
||||
|
||||
ND4J_EXPORT void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer* ptr);
|
||||
ND4J_EXPORT void deleteConstantDataBuffer(OpaqueConstantDataBuffer* ptr);
|
||||
|
||||
typedef sd::graph::Context OpaqueContext;
|
||||
typedef sd::graph::RandomGenerator OpaqueRandomGenerator;
|
||||
|
|
|
@ -245,7 +245,7 @@ void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
|
|||
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||
return;
|
||||
|
||||
if (!sd::Environment::getInstance()->isExperimentalBuild())
|
||||
if (!sd::Environment::getInstance().isExperimentalBuild())
|
||||
if ((yType != xType && yType != sd::DataType::BOOL) || xType != zType)
|
||||
throw sd::datatype_exception::build("NativeOps::execBroadcast both operands must have same data type", xType, yType);
|
||||
|
||||
|
@ -338,7 +338,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
|
|||
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||
return;
|
||||
|
||||
if (!sd::Environment::getInstance()->isExperimentalBuild())
|
||||
if (!sd::Environment::getInstance().isExperimentalBuild())
|
||||
if (yType != xType || sd::DataType::BOOL != zType)
|
||||
throw sd::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
|
||||
|
||||
|
@ -496,7 +496,7 @@ void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
#endif
|
||||
}
|
||||
|
@ -531,7 +531,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
}
|
||||
|
||||
|
@ -564,7 +564,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
}
|
||||
|
||||
|
@ -603,7 +603,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
|||
|
||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -631,7 +631,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
|||
|
||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -659,7 +659,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
|||
|
||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -687,7 +687,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
|||
|
||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads());
|
||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -844,13 +844,13 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
|
|||
sd::TadPack tadPack;
|
||||
|
||||
if(xLen == yLen) {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
else if(yLen > xLen) {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
else {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
@ -878,7 +878,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
|
|||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
|
||||
// TODO: make it 2d
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
@ -911,13 +911,13 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
|
|||
sd::TadPack tadPack;
|
||||
|
||||
if(xLen == yLen) {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
else if(yLen > xLen) {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
else {
|
||||
tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
}
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
@ -969,7 +969,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
#endif
|
||||
}
|
||||
|
@ -1006,7 +1006,7 @@ void NativeOpExecutioner::execScalar(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto yLen = shape::length(hScalarShapeInfo);
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance().maxMasterThreads()));
|
||||
|
||||
#endif
|
||||
}
|
||||
|
@ -1041,7 +1041,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1077,7 +1077,7 @@ void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto yLen = shape::length(hScalarShapeInfo);
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance().maxMasterThreads()));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1110,7 +1110,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto zLen = shape::length(hZShapeInfo);
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(zLen / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1146,7 +1146,7 @@ void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc,
|
|||
};
|
||||
|
||||
auto yLen = shape::length(hScalarShapeInfo);
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance()->maxMasterThreads()));
|
||||
samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min<int>(yLen, sd::Environment::getInstance().maxMasterThreads()));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1259,7 +1259,7 @@ void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc,
|
|||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1281,7 +1281,7 @@ void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc,
|
|||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1310,7 +1310,7 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
|
|||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1333,7 +1333,7 @@ void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc,
|
|||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1355,7 +1355,7 @@ void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc,
|
|||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
|
||||
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance().maxMasterThreads())));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -85,12 +85,12 @@ using namespace sd;
|
|||
|
||||
void setElementThreshold(int num) {
|
||||
if (num > 0)
|
||||
sd::Environment::getInstance()->setElementwiseThreshold(num);
|
||||
sd::Environment::getInstance().setElementwiseThreshold(num);
|
||||
}
|
||||
|
||||
void setTADThreshold(int num) {
|
||||
if (num > 0)
|
||||
sd::Environment::getInstance()->setTadThreshold(num);
|
||||
sd::Environment::getInstance().setTadThreshold(num);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -133,7 +133,7 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
||||
dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
||||
|
@ -184,8 +184,8 @@ void execBroadcast(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
||||
auto hTADOffsets = tadPackX.primaryOffsets();
|
||||
|
@ -223,8 +223,8 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
||||
auto hTADOffsets = tadPackX.primaryOffsets();
|
||||
|
@ -450,7 +450,7 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
||||
auto hTADOffsets = tadPackX.primaryOffsets();
|
||||
|
@ -485,7 +485,7 @@ void execReduceBool2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
||||
dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
||||
|
@ -521,7 +521,7 @@ void execReduceSame2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
||||
dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
||||
|
@ -557,7 +557,7 @@ void execReduceLong2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
||||
auto hTADOffsets = tadPack.primaryOffsets();
|
||||
|
@ -663,7 +663,7 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
|||
yTadOnlyShapeInfo, yTadOffsets);
|
||||
} else {
|
||||
// going tad-way
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
||||
dimensionLength);
|
||||
|
||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
||||
|
@ -1060,7 +1060,7 @@ void initializeDevicesAndFunctions() {
|
|||
}
|
||||
|
||||
void initializeFunctions(Nd4jPointer *functions) {
|
||||
sd::BlasHelper::getInstance()->initializeFunctions(functions);
|
||||
sd::BlasHelper::getInstance().initializeFunctions(functions);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1208,11 +1208,11 @@ int getAvailableDevices() {
|
|||
}
|
||||
|
||||
void enableDebugMode(bool reallyEnable) {
|
||||
sd::Environment::getInstance()->setDebug(reallyEnable);
|
||||
sd::Environment::getInstance().setDebug(reallyEnable);
|
||||
}
|
||||
|
||||
void enableVerboseMode(bool reallyEnable) {
|
||||
sd::Environment::getInstance()->setVerbose(reallyEnable);
|
||||
sd::Environment::getInstance().setVerbose(reallyEnable);
|
||||
}
|
||||
|
||||
void setGridLimit(int gridSize) {
|
||||
|
@ -1222,7 +1222,7 @@ void setGridLimit(int gridSize) {
|
|||
sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) {
|
||||
auto pack = new TadPack();
|
||||
try {
|
||||
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
*pack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
@ -1285,7 +1285,7 @@ void pullRowsGeneric(void *vx,
|
|||
|
||||
int elementsPerThread = n / TAD_THRESHOLD;
|
||||
int _threads = sd::math::nd4j_max<int>(1, elementsPerThread);
|
||||
_threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance()->maxThreads());
|
||||
_threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto idx = start; idx < stop; idx++) {
|
||||
|
@ -1557,7 +1557,7 @@ void shuffle(Nd4jPointer *extras,
|
|||
|
||||
|
||||
bool isExperimentalEnabled() {
|
||||
return sd::Environment::getInstance()->isExperimentalBuild();
|
||||
return sd::Environment::getInstance().isExperimentalBuild();
|
||||
}
|
||||
|
||||
|
||||
|
@ -1920,7 +1920,7 @@ Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) {
|
|||
}
|
||||
|
||||
const char* getAllCustomOps() {
|
||||
return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||
return sd::ops::OpRegistrator::getInstance().getAllCustomOperations();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -2016,7 +2016,7 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla
|
|||
|
||||
sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2047,7 +2047,7 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla
|
|||
|
||||
sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2059,7 +2059,7 @@ sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash,
|
|||
|
||||
int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
auto context = reinterpret_cast<Context *>(opContext);
|
||||
|
||||
return op->execute(context);
|
||||
|
@ -2157,7 +2157,7 @@ Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jL
|
|||
|
||||
int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace);
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
|
@ -2170,7 +2170,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat
|
|||
try {
|
||||
auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer);
|
||||
|
||||
sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph);
|
||||
sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph);
|
||||
|
||||
return ND4J_STATUS_OK;
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2181,7 +2181,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat
|
|||
}
|
||||
|
||||
static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) {
|
||||
auto graph = sd::graph::GraphHolder::getInstance()->cloneGraph(graphId);
|
||||
auto graph = sd::graph::GraphHolder::getInstance().cloneGraph(graphId);
|
||||
auto varSpace = graph->getVariableSpace();
|
||||
|
||||
std::vector<sd::NDArray*> handles;
|
||||
|
@ -2264,7 +2264,7 @@ void* getVariableBuffer(sd::graph::Variable* variable) {
|
|||
|
||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||
|
||||
sd::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||
sd::graph::GraphHolder::getInstance().dropGraphAny(graphId);
|
||||
|
||||
return sd::Status::OK();
|
||||
}
|
||||
|
@ -2294,7 +2294,7 @@ void deleteVariablesSet(sd::graph::VariablesSet* pointer) {
|
|||
}
|
||||
|
||||
const char* getAllOperations() {
|
||||
return sd::OpTracker::getInstance()->exportOperations();
|
||||
return sd::OpTracker::getInstance().exportOperations();
|
||||
}
|
||||
|
||||
|
||||
|
@ -2694,10 +2694,10 @@ void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
|
|||
}
|
||||
}
|
||||
|
||||
sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) {
|
||||
sd::ConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) {
|
||||
try {
|
||||
auto buffer = new ConstantDataBuffer();
|
||||
*buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo(
|
||||
auto buffer = new ConstantShapeBuffer();
|
||||
*buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo(
|
||||
ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty));
|
||||
return buffer;
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2707,10 +2707,14 @@ sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides
|
|||
}
|
||||
}
|
||||
|
||||
void deleteShapeBuffer(sd::ConstantDataBuffer* ptr) {
|
||||
void deleteConstantShapeBuffer(sd::ConstantShapeBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteConstantDataBuffer(sd::ConstantDataBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteTadPack(sd::TadPack* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
@ -2725,7 +2729,7 @@ sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, i
|
|||
|
||||
sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) {
|
||||
try {
|
||||
return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||
return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype);
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
@ -2733,6 +2737,14 @@ sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescripto
|
|||
}
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) {
|
||||
return const_cast<Nd4jLong*>(dbf->primary());
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantShapeBufferSpecial(sd::ConstantShapeBuffer* dbf) {
|
||||
return const_cast<Nd4jLong*>(dbf->special());
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) {
|
||||
return dbf->primary();
|
||||
}
|
||||
|
@ -2884,7 +2896,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
} else {
|
||||
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return const_cast<Nd4jLong*>(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
|
||||
return const_cast<Nd4jLong*>(sd::ConstantShapeHelper::getInstance().createFromExisting(shapeBuffer, true));
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
@ -2983,7 +2995,7 @@ const char* runLightBenchmarkSuit(bool printOut) {
|
|||
}
|
||||
|
||||
Nd4jLong getCachedMemory(int deviceId) {
|
||||
return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId);
|
||||
return sd::ConstantHelper::getInstance().getCachedAmount(deviceId);
|
||||
}
|
||||
|
||||
const char* runFullBenchmarkSuit(bool printOut) {
|
||||
|
|
|
@ -252,7 +252,7 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
|
|||
if (yType != xType)
|
||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type");
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("F3B opNum:[%i]\n", opNum);
|
||||
|
||||
dim3 launchDims(256, 256, 1024);
|
||||
|
@ -437,7 +437,7 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
|
|||
if (yType != xType || zType != xType)
|
||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("F3BI opNum:[%i]\n", opNum);
|
||||
|
||||
dim3 launchDims(256, 256, 1024);
|
||||
|
@ -583,7 +583,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
|||
auto stream = lc->getCudaStream();
|
||||
auto reductionPointer = lc->getReductionPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("SF7 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -618,7 +618,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
|||
auto stream = lc->getCudaStream();
|
||||
auto reductionPointer = lc->getReductionPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("LF7 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -654,7 +654,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
|||
auto stream = lc->getCudaStream();
|
||||
auto reductionPointer = lc->getReductionPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("BF7 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -701,7 +701,7 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
|
|||
auto reductionPointer = lc->getReductionPointer();
|
||||
auto allocationPointer = lc->getAllocationPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("F2 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -745,7 +745,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
|||
auto stream = lc->getCudaStream();
|
||||
auto reductionPointer = lc->getReductionPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("F8 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -780,7 +780,7 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
|
|||
void *hZ, Nd4jLong const* hZShapeInfo,
|
||||
void *dZ, Nd4jLong const* dZShapeInfo){
|
||||
|
||||
if (sd::Environment::getInstance()->isDebug())
|
||||
if (sd::Environment::getInstance().isDebug())
|
||||
printf("F1 opNum:[%i]\n", opNum);
|
||||
|
||||
auto stream = lc->getCudaStream();
|
||||
|
@ -792,7 +792,7 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
|
|||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose() && launchDims.x == 1)
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1)
|
||||
printf("AF1 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
@ -1649,12 +1649,12 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
|
|||
auto allocationPointer = lc->getAllocationPointer();
|
||||
auto reductionPointer = lc->getReductionPointer();
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("D119 opNum:[%i]\n", opNum);
|
||||
|
||||
dim3 launchDims(shape::length(hZShapeInfo), 256, 32768);
|
||||
|
||||
if (sd::Environment::getInstance()->isVerbose() && launchDims.x == 1)
|
||||
if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1)
|
||||
printf("AD119 opNum:[%i]\n", opNum);
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||
|
|
|
@ -237,9 +237,9 @@ void execPairwiseTransform( Nd4jPointer *extraPointers,
|
|||
InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY});
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(), extraParams);
|
||||
NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), extraParams);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -260,9 +260,9 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY});
|
||||
|
@ -284,9 +284,9 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
biasCorrected);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
|
@ -319,9 +319,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execBroadcastBool(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
dimension, dimensionLength,
|
||||
tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
||||
|
@ -373,9 +373,9 @@ void execBroadcast(
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execBroadcast(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
||||
|
||||
|
@ -407,9 +407,9 @@ void execReduceFloat(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceFloatScalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>());
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special());
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -429,9 +429,9 @@ void execReduceSame(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceSameScalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>());
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special());
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -454,15 +454,15 @@ void execReduceSame2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceSame(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||
|
||||
|
@ -487,15 +487,15 @@ void execReduceLong2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceLong(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||
|
||||
|
@ -534,9 +534,9 @@ void execReduceLong(Nd4jPointer *extraPointers,
|
|||
|
||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction,
|
||||
::execReduceScalar(launchDims, stream, opNum,
|
||||
dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(), hXShapeInfo,
|
||||
dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), hXShapeInfo,
|
||||
extraParams,
|
||||
dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(), hXShapeInfo,
|
||||
dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), hXShapeInfo,
|
||||
nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES);
|
||||
|
||||
sd::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed");
|
||||
|
@ -562,15 +562,15 @@ void execReduceBool2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceBool(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||
|
||||
|
@ -609,9 +609,9 @@ void execReduceBool(Nd4jPointer *extraPointers,
|
|||
|
||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction,
|
||||
::execReduceScalar(launchDims, stream, opNum,
|
||||
dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(), hXShapeInfo,
|
||||
dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), hXShapeInfo,
|
||||
extraParams,
|
||||
dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(), hZShapeInfo,
|
||||
dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), hZShapeInfo,
|
||||
nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
sd::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed");
|
||||
|
@ -648,15 +648,15 @@ void execIndexReduce(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execIndexReduce(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
(int *) dbDimension->special(), dimensionLength,
|
||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||
|
||||
|
@ -690,15 +690,15 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduceFloat(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||
|
||||
|
@ -728,9 +728,9 @@ void execIndexReduceScalar(
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execIndexReduceScalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>());
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special());
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -752,8 +752,8 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execTransformSame(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
tadShapeInfo, tadOffsets);
|
||||
|
||||
|
@ -777,8 +777,8 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execTransformBool(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
tadShapeInfo, tadOffsets);
|
||||
|
||||
|
@ -803,8 +803,8 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
|
|||
reinterpret_cast<int *>(extraPointers[6]));
|
||||
|
||||
NativeOpExecutioner::execTransformAny(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
nullptr, nullptr);
|
||||
|
||||
|
@ -828,8 +828,8 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execTransformStrict(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
tadShapeInfo, tadOffsets);
|
||||
|
||||
|
@ -853,8 +853,8 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execTransformFloat(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraParams,
|
||||
tadShapeInfo, tadOffsets);
|
||||
|
||||
|
@ -939,7 +939,7 @@ void enableP2P(bool enable) {
|
|||
cudaDeviceDisablePeerAccess(dY);
|
||||
}
|
||||
} else {
|
||||
if (sd::Environment::getInstance()->isVerbose()) printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY);
|
||||
if (sd::Environment::getInstance().isVerbose()) printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -983,7 +983,7 @@ void initializeDevicesAndFunctions() {
|
|||
}
|
||||
|
||||
void initializeFunctions(Nd4jPointer *functions) {
|
||||
sd::BlasHelper::getInstance()->initializeDeviceFunctions(functions);
|
||||
sd::BlasHelper::getInstance().initializeDeviceFunctions(functions);
|
||||
/*
|
||||
cublasSgemv = (CublasSgemv)functions[0];
|
||||
cublasDgemv = (CublasDgemv)functions[1];
|
||||
|
@ -1317,7 +1317,7 @@ int getAvailableDevices() {
|
|||
}
|
||||
|
||||
void enableDebugMode(bool reallyEnable) {
|
||||
sd::Environment::getInstance()->setDebug(reallyEnable);
|
||||
sd::Environment::getInstance().setDebug(reallyEnable);
|
||||
}
|
||||
|
||||
void setGridLimit(int gridSize) {
|
||||
|
@ -1345,7 +1345,7 @@ void setOmpNumThreads(int threads) {
|
|||
}
|
||||
|
||||
void enableVerboseMode(bool reallyEnable) {
|
||||
sd::Environment::getInstance()->setVerbose(reallyEnable);
|
||||
sd::Environment::getInstance().setVerbose(reallyEnable);
|
||||
}
|
||||
|
||||
int getDeviceMajor(int device) {
|
||||
|
@ -1386,7 +1386,7 @@ void specialConcat(
|
|||
sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* dXShapeInfo, int *dimension, int dimensionLength) {
|
||||
try {
|
||||
auto pack = new TadPack();
|
||||
*pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
|
||||
*pack = sd::ConstantTadHelper::getInstance().tadForDimensions(dXShapeInfo, dimension, dimensionLength);
|
||||
return pack;
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
|
@ -1502,7 +1502,7 @@ void average(Nd4jPointer *extras,
|
|||
|
||||
auto dX = reinterpret_cast<void **>(dx);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("averageFloat called\n");
|
||||
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
|
@ -1536,7 +1536,7 @@ void accumulate(Nd4jPointer *extras,
|
|||
|
||||
auto dX = reinterpret_cast<void **>(dx);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("accumulateFloat called\n");
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
|
||||
|
@ -1591,7 +1591,7 @@ void shuffle(Nd4jPointer *extras,
|
|||
}
|
||||
|
||||
bool isExperimentalEnabled() {
|
||||
return sd::Environment::getInstance()->isExperimentalBuild();
|
||||
return sd::Environment::getInstance().isExperimentalBuild();
|
||||
}
|
||||
|
||||
void setOmpMinThreads(int threads) {
|
||||
|
@ -1623,9 +1623,9 @@ void execSummaryStats(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execSummaryStats(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
biasCorrected);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
|
@ -1653,9 +1653,9 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execSummaryStats(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
reinterpret_cast<int *>(dbDimension->special()), dimensionLength,
|
||||
tadShapeInfo, tadOffsets,
|
||||
biasCorrected);
|
||||
|
@ -1679,10 +1679,10 @@ void execReduce3(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduce3(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>());
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special());
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -1708,7 +1708,7 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
|||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo,
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
||||
dimension,
|
||||
shape::length(hDimensionShape));
|
||||
auto tadLength = shape::length(tadPack.primaryShapeInfo());
|
||||
|
@ -1720,18 +1720,18 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
|||
if (tadLength == yLength || tadLength == xLength) {
|
||||
// nd4j_printf("== way\n","");
|
||||
NativeOpExecutioner::execReduce3(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets);
|
||||
} else
|
||||
NativeOpExecutioner::execReduce3TAD(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets);
|
||||
|
||||
|
@ -1753,10 +1753,10 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduce3Scalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>());
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special());
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY});
|
||||
} catch (std::exception &e) {
|
||||
|
@ -1777,9 +1777,9 @@ void execScalarBool(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execScalarBool(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(),
|
||||
extraParams);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar});
|
||||
|
@ -1808,10 +1808,10 @@ void execScalarBoolTad(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execScalarBool(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParams,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(),
|
||||
dimension, dimensionLength,
|
||||
tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
|
||||
|
||||
|
@ -1834,9 +1834,9 @@ void execScalar(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execScalar(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hScalarShapeInfo).special(),
|
||||
extraParams);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar});
|
||||
|
@ -1877,7 +1877,7 @@ void execScalarTad(Nd4jPointer *extraPointers,
|
|||
#ifdef __ND4J_EXPERIMENTAL__
|
||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
#else
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(), dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(), dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
||||
#endif
|
||||
|
||||
DEBUG_KERNEL(stream, opNum);
|
||||
|
@ -1938,7 +1938,7 @@ void execRandom(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execRandom(&lc, opNum, stateHost,
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraArguments);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {});
|
||||
|
@ -1958,8 +1958,8 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execRandom(&lc, opNum, stateHost,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraArguments);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||
|
@ -1980,9 +1980,9 @@ void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execRandom(&lc, opNum, stateHost,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
extraArguments);
|
||||
|
||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY});
|
||||
|
@ -2216,10 +2216,10 @@ void execReduce3All(Nd4jPointer *extraPointers,
|
|||
|
||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||
NativeOpExecutioner::execReduce3All(&lc, opNum,
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||
extraParamsVals,
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT<Nd4jLong>(),
|
||||
dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hYShapeInfo).special(),
|
||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
||||
reinterpret_cast<int *>(dbDimension->special()), dimensionLength,
|
||||
xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets);
|
||||
|
||||
|
@ -2458,7 +2458,7 @@ void sortTadByKey(Nd4jPointer *extraPointers,
|
|||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext()
|
||||
: reinterpret_cast<LaunchContext *>(extraPointers[0]);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
auto yType = sd::ArrayOptions::dataType(yShapeInfo);
|
||||
|
@ -2485,7 +2485,7 @@ void sortTadByValue(Nd4jPointer *extraPointers,
|
|||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext()
|
||||
: reinterpret_cast<LaunchContext *>(extraPointers[0]);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
|
||||
auto xType = sd::ArrayOptions::dataType(yShapeInfo);
|
||||
auto yType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
|
@ -2515,7 +2515,7 @@ void sortTad(Nd4jPointer *extraPointers,
|
|||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext()
|
||||
: reinterpret_cast<LaunchContext *>(extraPointers[0]);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768);
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
BUILD_SINGLE_SELECTOR(xType, oesTadGeneric,
|
||||
|
@ -2561,7 +2561,7 @@ Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) {
|
|||
|
||||
|
||||
const char* getAllCustomOps() {
|
||||
return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||
return sd::ops::OpRegistrator::getInstance().getAllCustomOperations();
|
||||
}
|
||||
|
||||
|
||||
|
@ -2608,7 +2608,7 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla
|
|||
|
||||
sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs,
|
||||
iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs);
|
||||
|
@ -2639,7 +2639,7 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla
|
|||
|
||||
sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
|
||||
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2742,7 +2742,7 @@ static FORCEINLINE Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* e
|
|||
|
||||
int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
|
||||
return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes,
|
||||
numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace);
|
||||
|
@ -2755,7 +2755,7 @@ int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBu
|
|||
|
||||
int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) {
|
||||
try {
|
||||
auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash);
|
||||
auto op = sd::ops::OpRegistrator::getInstance().getOperation(hash);
|
||||
auto context = reinterpret_cast<Context *>(opContext);
|
||||
|
||||
auto result = op->execute(context);
|
||||
|
@ -2786,7 +2786,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat
|
|||
try {
|
||||
auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer);
|
||||
|
||||
sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph);
|
||||
sd::graph::GraphHolder::getInstance().registerGraph(graphId, graph);
|
||||
|
||||
return ND4J_STATUS_OK;
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2798,7 +2798,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat
|
|||
|
||||
|
||||
static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) {
|
||||
auto graph = sd::graph::GraphHolder::getInstance()->pullGraph(graphId);
|
||||
auto graph = sd::graph::GraphHolder::getInstance().pullGraph(graphId);
|
||||
auto varSpace = graph->getVariableSpace()->clone();
|
||||
|
||||
std::vector<sd::NDArray*> handles;
|
||||
|
@ -2887,7 +2887,7 @@ void* getVariableBuffer(sd::graph::Variable* variable) {
|
|||
|
||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||
try {
|
||||
sd::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||
sd::graph::GraphHolder::getInstance().dropGraphAny(graphId);
|
||||
|
||||
return ND4J_STATUS_OK;
|
||||
} catch (std::exception &e) {
|
||||
|
@ -2929,7 +2929,7 @@ void deleteShapeList(Nd4jPointer shapeList) {
|
|||
}
|
||||
|
||||
const char* getAllOperations() {
|
||||
return sd::OpTracker::getInstance()->exportOperations();
|
||||
return sd::OpTracker::getInstance().exportOperations();
|
||||
}
|
||||
|
||||
Nd4jPointer getGraphState(Nd4jLong id) {
|
||||
|
@ -3360,7 +3360,7 @@ void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
|
|||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
|
||||
tryPointerKernel << < 256, 512, len + 64, stream >> > (p, len);
|
||||
tryPointerKernel <<< 256, 512, len + 64, stream>>> (p, len);
|
||||
auto e = cudaStreamSynchronize(stream);
|
||||
|
||||
if (e != 0)
|
||||
|
@ -3376,10 +3376,11 @@ void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
|
|||
int dataTypeFromNpyHeader(void *header) {
|
||||
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
|
||||
}
|
||||
sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) {
|
||||
|
||||
OpaqueConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) {
|
||||
try {
|
||||
auto buffer = new ConstantDataBuffer();
|
||||
*buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo(
|
||||
auto buffer = new ConstantShapeBuffer();
|
||||
*buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo(
|
||||
ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty));
|
||||
return buffer;
|
||||
} catch (std::exception &e) {
|
||||
|
@ -3389,19 +3390,23 @@ sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides
|
|||
}
|
||||
}
|
||||
|
||||
void deleteShapeBuffer(sd::ConstantDataBuffer* ptr) {
|
||||
void deleteConstantShapeBuffer(OpaqueConstantShapeBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteConstantDataBuffer(OpaqueConstantDataBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteTadPack(sd::TadPack* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
bool isBlasVersionMatches(int major, int minor, int build) {
|
||||
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
|
||||
auto result = major == Environment::getInstance()._blasMajorVersion && minor == Environment::getInstance()._blasMinorVersion && build == Environment::getInstance()._blasPatchVersion;
|
||||
|
||||
if (!result) {
|
||||
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
|
||||
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()._blasMajorVersion, Environment::getInstance()._blasMinorVersion, Environment::getInstance()._blasPatchVersion, major, minor, build);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
|
||||
}
|
||||
|
@ -3410,15 +3415,15 @@ bool isBlasVersionMatches(int major, int minor, int build) {
|
|||
}
|
||||
|
||||
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length) {
|
||||
return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||
return sd::ConstantHelper::getInstance().constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||
}
|
||||
|
||||
sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) {
|
||||
return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||
return sd::ConstantHelper::getInstance().constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||
}
|
||||
|
||||
sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) {
|
||||
return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||
return sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype);
|
||||
}
|
||||
|
||||
|
||||
|
@ -3435,6 +3440,13 @@ Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) {
|
|||
return dbf->sizeOf();
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) {
|
||||
return const_cast<Nd4jLong*>(dbf->primary());
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantShapeBufferSpecial(sd::ConstantShapeBuffer* dbf) {
|
||||
return const_cast<Nd4jLong*>(dbf->special());
|
||||
}
|
||||
|
||||
sd::graph::Context* createGraphContext(int nodeId) {
|
||||
return new sd::graph::Context(nodeId);
|
||||
|
@ -3563,7 +3575,7 @@ Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
} else {
|
||||
shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes
|
||||
return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance().createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
@ -3612,7 +3624,7 @@ const char* runFullBenchmarkSuit(bool printOut) {
|
|||
}
|
||||
|
||||
Nd4jLong getCachedMemory(int deviceId) {
|
||||
return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId);
|
||||
return sd::ConstantHelper::getInstance().getCachedAmount(deviceId);
|
||||
}
|
||||
|
||||
sd::LaunchContext* defaultLaunchContext() {
|
||||
|
|
|
@ -214,11 +214,9 @@ namespace sd {
|
|||
_maxDeviceMemory = maxBytes;
|
||||
}
|
||||
|
||||
Environment *Environment::getInstance() {
|
||||
if (_instance == 0)
|
||||
_instance = new Environment();
|
||||
|
||||
return _instance;
|
||||
Environment& Environment::getInstance() {
|
||||
static Environment instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool Environment::isVerbose() {
|
||||
|
@ -353,27 +351,27 @@ namespace sd {
|
|||
}
|
||||
|
||||
void Environment::setGroupLimit(int group, Nd4jLong numBytes) {
|
||||
sd::memory::MemoryCounter::getInstance()->setGroupLimit((sd::memory::MemoryType) group, numBytes);
|
||||
sd::memory::MemoryCounter::getInstance().setGroupLimit((sd::memory::MemoryType) group, numBytes);
|
||||
}
|
||||
|
||||
void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) {
|
||||
sd::memory::MemoryCounter::getInstance()->setDeviceLimit(deviceId, numBytes);
|
||||
sd::memory::MemoryCounter::getInstance().setDeviceLimit(deviceId, numBytes);
|
||||
}
|
||||
|
||||
Nd4jLong Environment::getGroupLimit(int group) {
|
||||
return sd::memory::MemoryCounter::getInstance()->groupLimit((sd::memory::MemoryType) group);
|
||||
return sd::memory::MemoryCounter::getInstance().groupLimit((sd::memory::MemoryType) group);
|
||||
}
|
||||
|
||||
Nd4jLong Environment::getDeviceLimit(int deviceId) {
|
||||
return sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId);
|
||||
return sd::memory::MemoryCounter::getInstance().deviceLimit(deviceId);
|
||||
}
|
||||
|
||||
Nd4jLong Environment::getGroupCounter(int group) {
|
||||
return sd::memory::MemoryCounter::getInstance()->allocatedGroup((sd::memory::MemoryType) group);
|
||||
return sd::memory::MemoryCounter::getInstance().allocatedGroup((sd::memory::MemoryType) group);
|
||||
}
|
||||
|
||||
Nd4jLong Environment::getDeviceCounter(int deviceId) {
|
||||
return sd::memory::MemoryCounter::getInstance()->allocatedDevice(deviceId);
|
||||
return sd::memory::MemoryCounter::getInstance().allocatedDevice(deviceId);
|
||||
}
|
||||
|
||||
uint64_t Environment::maxPrimaryMemory() {
|
||||
|
@ -383,7 +381,4 @@ namespace sd {
|
|||
uint64_t Environment::maxSpecialMemory() {
|
||||
return _maxTotalSpecialMemory.load();
|
||||
}
|
||||
|
||||
sd::Environment *sd::Environment::_instance = 0;
|
||||
|
||||
}
|
||||
|
|
|
@ -103,7 +103,7 @@ namespace broadcast {
|
|||
auto tadOffsets = xTadOffset;
|
||||
|
||||
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
xTadShapeShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
|
@ -396,7 +396,7 @@ namespace broadcast {
|
|||
auto tadOffsets = yTadOffset;
|
||||
|
||||
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
|
||||
yTadShapeShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
|
@ -416,7 +416,7 @@ namespace broadcast {
|
|||
|
||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||
int threads = sd::math::nd4j_max<int>(1, tadsPerThread);
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance()->maxThreads());
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
|
|
@ -115,7 +115,7 @@ namespace broadcast {
|
|||
auto tadOffsets = xTadOffset;
|
||||
|
||||
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
|
||||
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
|
||||
|
@ -135,7 +135,7 @@ namespace broadcast {
|
|||
|
||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||
int threads = sd::math::nd4j_max<int>(1, tadsPerThread);
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance()->maxThreads());
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||
|
@ -280,7 +280,7 @@ namespace broadcast {
|
|||
auto tadOffsets = yTadOffset;
|
||||
|
||||
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
|
||||
yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
|
||||
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
|
||||
|
@ -300,7 +300,7 @@ namespace broadcast {
|
|||
|
||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||
int threads = sd::math::nd4j_max<int>(1, tadsPerThread);
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance()->maxThreads());
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
|
|
@ -108,7 +108,7 @@ namespace functions {
|
|||
auto tadOffsets = xTadOffset;
|
||||
|
||||
if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
xTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
|
||||
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
|
||||
|
@ -128,7 +128,7 @@ namespace functions {
|
|||
|
||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||
int threads = sd::math::nd4j_max<int>(1, tadsPerThread);
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance()->maxThreads());
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||
|
@ -271,7 +271,7 @@ namespace functions {
|
|||
auto tadOffsets = yTadOffset;
|
||||
|
||||
if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(yShapeInfo, dimension, dimensionLength);
|
||||
|
||||
yTadShapeShapeInfo = const_cast<Nd4jLong*>(tadPack.primaryShapeInfo());
|
||||
tadOffsets = const_cast<Nd4jLong*>(tadPack.primaryOffsets());
|
||||
|
@ -291,7 +291,7 @@ namespace functions {
|
|||
|
||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||
int threads = sd::math::nd4j_max<int>(1, tadsPerThread);
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance()->maxThreads());
|
||||
threads = sd::math::nd4j_min<int>(threads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
|
|
@ -64,7 +64,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(const void *vx, const Nd4jLong *xShapeInf
|
|||
|
||||
uint xShapeInfoCast[MAX_RANK];
|
||||
bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
IndexValue<X> intermediatery[64];
|
||||
for (int e = 0; e < maxThreads; e++)
|
||||
intermediatery[e].index = -1;
|
||||
|
@ -142,7 +142,7 @@ void IndexReduce<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
if (dimensionLength < 1)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
|
|
|
@ -166,7 +166,7 @@ namespace functions {
|
|||
if (dimensionLength < 1)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
}
|
||||
|
@ -193,7 +193,7 @@ namespace functions {
|
|||
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||
auto x = reinterpret_cast<const X *>(vx);
|
||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
|
|
@ -70,7 +70,7 @@ namespace functions {
|
|||
auto startingValue = OpType::startingValue(x);
|
||||
uint xShapeInfoCast[MAX_RANK];
|
||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
@ -200,7 +200,7 @@ namespace functions {
|
|||
if (dimensionLength < 0)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
}
|
||||
|
@ -229,7 +229,7 @@ namespace functions {
|
|||
|
||||
auto x = reinterpret_cast<const X *>(vx);
|
||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace functions {
|
|||
auto startingValue = OpType::startingValue(x);
|
||||
uint xShapeInfoCast[MAX_RANK];
|
||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
@ -187,7 +187,7 @@ namespace functions {
|
|||
if (dimensionLength < 1)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
}
|
||||
|
@ -215,7 +215,7 @@ namespace functions {
|
|||
|
||||
auto x = reinterpret_cast<const X *>(vx);
|
||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
|
|
@ -67,7 +67,7 @@ namespace functions {
|
|||
auto startingValue = OpType::startingValue(x);
|
||||
uint xShapeInfoCast[MAX_RANK];
|
||||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
X intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
@ -196,7 +196,7 @@ namespace functions {
|
|||
if (dimensionLength < 1)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
||||
tadOffsets = tadPack.primaryOffsets();
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ namespace functions {
|
|||
|
||||
auto x = reinterpret_cast<const X *>(vx);
|
||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
X intermediate[64];
|
||||
|
||||
PRAGMA_OMP_SIMD
|
||||
|
|
|
@ -65,7 +65,7 @@ void Reduce3<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||
|
||||
Z startingVal = OpType::startingValue(x);
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance()->maxThreads());
|
||||
int maxThreads = sd::math::nd4j_min<int>(64, sd::Environment::getInstance().maxThreads());
|
||||
Z intermediate[64];
|
||||
Z extraParamsLocal[3 * 64];
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ void ScalarTransform<X, Y, Z>::transform(const void *vx, const Nd4jLong *xShapeI
|
|||
return;
|
||||
}
|
||||
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance()->maxThreads());
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
if (kindOfLoop == sd::LoopKind::EWS1) {
|
||||
for (auto r = start; r < stop; r++) {
|
||||
|
|
|
@ -66,7 +66,7 @@ namespace functions {
|
|||
return;
|
||||
}
|
||||
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance()->maxThreads());
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
if (kindOfLoop == sd::LoopKind::EWS1) {
|
||||
for (auto r = start; r < stop; r++) {
|
||||
|
|
|
@ -66,7 +66,7 @@ namespace functions {
|
|||
return;
|
||||
}
|
||||
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance()->maxThreads());
|
||||
int num_threads = sd::math::nd4j_min<int>(numTads, sd::Environment::getInstance().maxThreads());
|
||||
|
||||
if (kindOfLoop == sd::LoopKind::EWS1) {
|
||||
for (auto r = start; r < stop; r++) {
|
||||
|
|
|
@ -127,7 +127,7 @@ namespace functions {
|
|||
if (dimensionLength < 1)
|
||||
return;
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
|
|
|
@ -173,7 +173,7 @@ namespace functions {
|
|||
|
||||
DISPATCH_SIMPLE(transformShaped, float16, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
|
||||
|
||||
if (sd::Environment::getInstance()->isDebug())
|
||||
if (sd::Environment::getInstance().isDebug())
|
||||
checkCudaErrors(cudaStreamSynchronize(*stream));
|
||||
}
|
||||
|
||||
|
|
|
@ -152,7 +152,7 @@ void _CUDA_H ScalarTransform<X,Y,Z>::intermediateAlongDimension(dim3& launchDims
|
|||
template<typename X, typename Y, typename Z>
|
||||
void ScalarTransform<X,Y,Z>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, void const* vscalar, void *vextraParams) {
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("H14 opNum:[%i]\n", opNum);
|
||||
|
||||
DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, hxShapeInfo, vz, zShapeInfo, hzShapeInfo, vscalar, vextraParams, nullptr), SCALAR_OPS);
|
||||
|
|
|
@ -218,7 +218,7 @@ void ScalarBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t
|
|||
void const* vscalar,
|
||||
void const* vextraParams) {
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("H14 opNum:[%i]\n", opNum);
|
||||
|
||||
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, const_cast<void*>(vextraParams), nullptr), SCALAR_BOOL_OPS);
|
||||
|
|
|
@ -216,7 +216,7 @@ void ScalarIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *st
|
|||
void const* vscalar,
|
||||
void* vextraParams) {
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("H14 opNum:[%i]\n", opNum);
|
||||
|
||||
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS);
|
||||
|
|
|
@ -344,7 +344,7 @@ void _CUDA_G summaryStatsReduceT(int op, void const* dx, Nd4jLong const* xShapeI
|
|||
auto z = reinterpret_cast<Z*>(vz);
|
||||
auto reductionPointerA = reinterpret_cast<Z*>(reductionBuffer);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("D16 opNum:[%i]\n", opNum);
|
||||
|
||||
summaryStatsReduceT<X,Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
|
||||
|
@ -369,7 +369,7 @@ void _CUDA_G summaryStatsReduceT(int op, void const* dx, Nd4jLong const* xShapeI
|
|||
auto z = static_cast<Z*>(vz);
|
||||
auto extraParams = static_cast<Z*>(vextraParams);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("F17 opNum:[%i]\n", opNum);
|
||||
|
||||
auto reductionPointerA = reinterpret_cast<Z*>(reductionBuffer);
|
||||
|
@ -396,7 +396,7 @@ void _CUDA_G summaryStatsReduceT(int op, void const* dx, Nd4jLong const* xShapeI
|
|||
auto z = static_cast<Z*>(vz);
|
||||
auto extraParams = static_cast<Z*>(vextraParams);
|
||||
|
||||
if (sd::Environment::getInstance()->isDebugAndVerbose())
|
||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||
printf("D18 opNum:[%i]\n", opNum);
|
||||
|
||||
summaryStatsReduceT<X, Z><<<launchDims.x,launchDims.y,launchDims.z, *stream>>>(
|
||||
|
|
|
@ -34,8 +34,6 @@ namespace sd {
|
|||
*/
|
||||
class ND4J_EXPORT MemoryCounter {
|
||||
private:
|
||||
static MemoryCounter* _INSTANCE;
|
||||
|
||||
// used for synchronization
|
||||
std::mutex _locker;
|
||||
|
||||
|
@ -56,7 +54,7 @@ namespace sd {
|
|||
~MemoryCounter() = default;
|
||||
|
||||
public:
|
||||
static MemoryCounter *getInstance();
|
||||
static MemoryCounter & getInstance();
|
||||
|
||||
/**
|
||||
* This method checks if allocation of numBytes won't break through per-group or per-device limit
|
||||
|
|
|
@ -32,7 +32,6 @@ namespace sd {
|
|||
namespace memory {
|
||||
class ND4J_EXPORT MemoryRegistrator {
|
||||
protected:
|
||||
static MemoryRegistrator* _INSTANCE;
|
||||
Workspace* _workspace;
|
||||
MAP_IMPL<Nd4jLong, Nd4jLong> _footprint;
|
||||
std::mutex _lock;
|
||||
|
@ -40,7 +39,7 @@ namespace sd {
|
|||
MemoryRegistrator();
|
||||
~MemoryRegistrator() = default;
|
||||
public:
|
||||
static MemoryRegistrator* getInstance();
|
||||
static MemoryRegistrator& getInstance();
|
||||
bool hasWorkspaceAttached();
|
||||
Workspace* getWorkspace();
|
||||
void attachWorkspace(Workspace* workspace);
|
||||
|
|
|
@ -35,7 +35,6 @@ namespace sd {
|
|||
*/
|
||||
class ND4J_EXPORT MemoryTracker {
|
||||
private:
|
||||
static MemoryTracker* _INSTANCE;
|
||||
std::map<Nd4jLong, AllocationEntry> _allocations;
|
||||
std::map<Nd4jLong, AllocationEntry> _released;
|
||||
std::mutex _locker;
|
||||
|
@ -43,7 +42,7 @@ namespace sd {
|
|||
MemoryTracker();
|
||||
~MemoryTracker() = default;
|
||||
public:
|
||||
static MemoryTracker* getInstance();
|
||||
static MemoryTracker& getInstance();
|
||||
|
||||
void countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes);
|
||||
void countOut(Nd4jPointer ptr);
|
||||
|
|
|
@ -36,19 +36,17 @@ namespace sd {
|
|||
}
|
||||
|
||||
// setting initial values for limits
|
||||
_groupLimits[sd::memory::MemoryType::HOST] = sd::Environment::getInstance()->maxPrimaryMemory();
|
||||
_groupLimits[sd::memory::MemoryType::DEVICE] = sd::Environment::getInstance()->maxSpecialMemory();
|
||||
_groupLimits[sd::memory::MemoryType::HOST] = sd::Environment::getInstance().maxPrimaryMemory();
|
||||
_groupLimits[sd::memory::MemoryType::DEVICE] = sd::Environment::getInstance().maxSpecialMemory();
|
||||
|
||||
// setting initial counter values
|
||||
_groupCounters[sd::memory::MemoryType::HOST] = 0;
|
||||
_groupCounters[sd::memory::MemoryType::DEVICE] = 0;
|
||||
}
|
||||
|
||||
MemoryCounter* MemoryCounter::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new MemoryCounter();
|
||||
|
||||
return _INSTANCE;
|
||||
MemoryCounter& MemoryCounter::getInstance() {
|
||||
static MemoryCounter instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) {
|
||||
|
@ -127,7 +125,5 @@ namespace sd {
|
|||
std::lock_guard<std::mutex> lock(_locker);
|
||||
return _groupLimits[group];
|
||||
}
|
||||
|
||||
MemoryCounter* MemoryCounter::_INSTANCE = 0;
|
||||
}
|
||||
}
|
|
@ -27,11 +27,9 @@ namespace sd {
|
|||
_workspace = nullptr;
|
||||
};
|
||||
|
||||
MemoryRegistrator* MemoryRegistrator::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new MemoryRegistrator();
|
||||
|
||||
return _INSTANCE;
|
||||
MemoryRegistrator& MemoryRegistrator::getInstance() {
|
||||
static MemoryRegistrator instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool MemoryRegistrator::hasWorkspaceAttached() {
|
||||
|
@ -83,8 +81,5 @@ namespace sd {
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
MemoryRegistrator* MemoryRegistrator::_INSTANCE = 0;
|
||||
|
||||
}
|
||||
}
|
|
@ -40,11 +40,9 @@ namespace sd {
|
|||
//
|
||||
}
|
||||
|
||||
MemoryTracker* MemoryTracker::getInstance() {
|
||||
if (_INSTANCE == 0)
|
||||
_INSTANCE = new MemoryTracker();
|
||||
|
||||
return _INSTANCE;
|
||||
MemoryTracker& MemoryTracker::getInstance() {
|
||||
static MemoryTracker instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD)
|
||||
|
@ -99,7 +97,7 @@ namespace sd {
|
|||
|
||||
void MemoryTracker::countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes) {
|
||||
#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD)
|
||||
if (Environment::getInstance()->isDetectingLeaks()) {
|
||||
if (Environment::getInstance().isDetectingLeaks()) {
|
||||
auto lptr = reinterpret_cast<Nd4jLong>(ptr);
|
||||
|
||||
_locker.lock();
|
||||
|
@ -133,7 +131,7 @@ namespace sd {
|
|||
|
||||
void MemoryTracker::countOut(Nd4jPointer ptr) {
|
||||
#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD)
|
||||
if (Environment::getInstance()->isDetectingLeaks()) {
|
||||
if (Environment::getInstance().isDetectingLeaks()) {
|
||||
auto lptr = reinterpret_cast<Nd4jLong>(ptr);
|
||||
|
||||
_locker.lock();
|
||||
|
@ -172,7 +170,5 @@ namespace sd {
|
|||
_allocations.clear();
|
||||
_released.clear();
|
||||
}
|
||||
|
||||
MemoryTracker* MemoryTracker::_INSTANCE = 0;
|
||||
}
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue