cavis/libnd4j/include/graph/impl/Variable.cpp

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <helpers/EnumUtils.h>
#include <graph/Variable.h>
#include <array/DataTypeUtils.h>
#include <array/ByteOrderUtils.h>
#include <array/DataTypeConversions.h>
#include <graph/FlatUtils.h>
#include <helpers/StringUtils.h>
namespace sd {
namespace graph {
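// Produces a copy of this Variable with its NDArray payload cast to type N.
// NDArrayList-backed variables are not supported here and trigger a runtime_error.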
template <typename N>
Variable* Variable::asT() {
    auto result = new Variable(this->isPlaceholder());

    result->markExternal(this->_external);
    result->setId(this->_id);
    result->markReadOnly(this->_readOnly);
    result->setName(&this->_name);
    result->setIndex(this->_index);

    if (this->_ndarray != nullptr)
        result->setNDArray(new NDArray(this->_ndarray->template asT<N>()));

    // FIXME: add support for ArrayList
    if (this->_list != nullptr) {
        nd4j_printf("ArrayList not supported yet\n", "");
        throw std::runtime_error("ArrayList not supported yet for asT");
    }

    return result;
}
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES);

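// Deep copy: the backing NDArray (if any) is duplicated and the copy is marked
// removable, while an attached NDArrayList is cloned as well.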
sd::graph::Variable* sd::graph::Variable::clone() {
    auto result = new Variable(this->isPlaceholder());
    result->_external = this->_external;
    result->_id = this->_id;
    result->_readOnly = this->_readOnly;
    result->_name = this->_name;
    result->_index = this->_index;

    if (this->_ndarray != nullptr) {
        result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
        result->_readOnly = false;
        result->_removable = true;
    }

    if (this->_list != nullptr)
        result->_list = this->_list->clone();

    return result;
}

void sd::graph::Variable::setIndex(int index) {
    _index = index;
}

bool sd::graph::Variable::hasNDArray() {
    return _ndarray != nullptr;
}

void sd::graph::Variable::setVariableType(VariableType variableType) {
    _variableType = variableType;
}

bool sd::graph::Variable::hasNDArrayList() {
    return _list != nullptr;
}

bool sd::graph::Variable::isPlaceholder() {
    return _placeholder;
}

std::string * sd::graph::Variable::getName() {
    return &_name;
}

void sd::graph::Variable::setName(std::string *name) {
    _name = *name;
}

int sd::graph::Variable::id() {
    return _id;
}

int sd::graph::Variable::index() {
    return _index;
}

void sd::graph::Variable::setId(int id) {
    _id = id;
}

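// A variable is considered empty when its backing NDArray is missing or not initialized,
// or, for ARRAY_LIST variables, when no list is attached.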
bool sd::graph::Variable::isEmpty() {
    if (_variableType == VariableType::NDARRAY)
        return _ndarray == nullptr || !_ndarray->nonNull();
    else if (_variableType == VariableType::ARRAY_LIST)
        return _list == nullptr;

    return false;
}

bool sd::graph::Variable::isExternal() {
    return _external;
}

bool sd::graph::Variable::isReadOnly() {
    return _readOnly;
}

void sd::graph::Variable::markExternal(bool reallyExternal) {
    this->_external = reallyExternal;
}

void sd::graph::Variable::markRemovable(bool reallyRemovable) {
    if (!reallyRemovable)
        nd4j_debug("","");

    this->_removable = reallyRemovable;
}

void sd::graph::Variable::markReadOnly(bool reallyReadOnly) {
    this->_readOnly = reallyReadOnly;
}

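// Returns the backing NDArray; prints a diagnostic if the variable type is not NDARRAY
// and throws if no array has been attached to this variable.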
sd::NDArray * sd::graph::Variable::getNDArray() {
    if (_variableType != VariableType::NDARRAY) {
        nd4j_printf("Variable[%i:%i/<%s>] has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType));
    }

    if (this->_ndarray == nullptr) {
        if (_name.empty()) {
            auto nodeId = StringUtils::valueToString<int>(this->id());
            auto outputIndex = StringUtils::valueToString<int>(this->index());
            throw std::runtime_error("Array doesn't exist for Variable <" + nodeId + ":" + outputIndex + ">");
        } else {
            auto outputIndex = StringUtils::valueToString<int>(this->index());
            throw std::runtime_error("Array doesn't exist for Variable <" + this->_name + ":" + outputIndex + ">");
        }
    }

    return this->_ndarray;
}

sd::NDArrayList * sd::graph::Variable::getNDArrayList() {
    if (_variableType != VariableType::ARRAY_LIST) {
        nd4j_debug("Variable[%i:%i/<%s>] has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType));
    }

    return this->_list;
}

bool Variable::isRemovable() {
    return _removable;
}

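// Attaching a list or an array also switches the variable type accordingly.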
void sd::graph::Variable::setNDArrayList(sd::NDArrayList * list) {
    this->_variableType = VariableType::ARRAY_LIST;
    this->_list = list;
}

void sd::graph::Variable::setNDArray(sd::NDArray * array) {
    this->_variableType = VariableType::NDARRAY;
    this->_ndarray = array;
}

VariableType sd::graph::Variable::variableType() {
    return _variableType;
}

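// Restores a Variable from its FlatBuffers representation: CONSTANT variables must carry
// an NDArray, PLACEHOLDER variables must at least carry a shape; other types deserialize
// the bundled array when present.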
sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) {
    auto vid = flatVariable->id();
    this->_id = vid->first();
    this->_index = vid->second();

    if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0)
        this->_name = flatVariable->name()->str();

    _external = true;
    _readOnly = false;

    int8_t *buffer = nullptr;

    switch (flatVariable->variabletype()) {
        case VarType_VARIABLE: {
            // ?????
            if (flatVariable->ndarray() != nullptr) {
                auto ar = flatVariable->ndarray();
                _ndarray = sd::graph::FlatUtils::fromFlatArray(ar);
            }

            _variableType = VariableType::NDARRAY;
        }
        break;
        case VarType_CONSTANT: {
            if (flatVariable->ndarray() == nullptr)
                throw std::runtime_error("CONSTANT variable must have NDArray bundled");

            auto ar = flatVariable->ndarray();
            if (ar->dtype() == DType_UTF8) {
                _ndarray = sd::graph::FlatUtils::fromFlatArray(ar);
            } else {
                _ndarray = sd::graph::FlatUtils::fromFlatArray(ar);
            }

            _variableType = VariableType::NDARRAY;
        }
        break;
        case VarType_ARRAY: {
            // ?????
            if (flatVariable->ndarray() != nullptr) {
                auto ar = flatVariable->ndarray();
                _ndarray = sd::graph::FlatUtils::fromFlatArray(ar);
                // _ndarray->triggerAllocationFlag(true);
            }

            _variableType = VariableType::NDARRAY;
        }
        break;
        case VarType_PLACEHOLDER: {
            if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr)
                throw std::runtime_error("PLACEHOLDER variable must have shape defined");

            if (flatVariable->ndarray() != nullptr) {
                auto ar = flatVariable->ndarray();
                _ndarray = sd::graph::FlatUtils::fromFlatArray(ar);
                // _ndarray->triggerAllocationFlag(true);

                _variableType = VariableType::NDARRAY;
            }

            if (flatVariable->shape() != nullptr) {
                int shapeLen = flatVariable->shape()->Length();
                for (int i = 0; i < flatVariable->shape()->size(); i++)
                    _shape.emplace_back(flatVariable->shape()->Get(i));

                if (_ndarray == nullptr)
                    _variableType = VariableType::PLACEHOLDER;
            }
        }
        break;
        default:
            throw std::runtime_error("Unknown variable type used");
    }
}

std::vector<Nd4jLong>& sd::graph::Variable::shape() {
    return _shape;
}

sd::graph::Variable::Variable(bool placeholder) {
    _placeholder = placeholder;
}

sd::graph::Variable::Variable(NDArray *array, const char *name) {
    _ndarray = array;

    _external = false;
    _readOnly = false;

    if (name != nullptr)
        _name = std::string(name);

    if (_ndarray != nullptr)
        _variableType = VariableType::NDARRAY;
}

sd::graph::Variable::Variable(NDArray *array, const char *name, int id, int idx) : Variable(array, name) {
    _id = id;
    _index = idx;
}

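// The destructor only frees the NDArray when this variable owns it, i.e. when it is
// marked removable and not read-only.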
sd::graph::Variable::~Variable() {
    //nd4j_printf("Removing variable [%i:%i]\n", _id, _index);
    if (_variableType == VariableType::NDARRAY) {
        nd4j_debug("Removing variable <%i:%i>\n", _id, _index);

        if (_ndarray != nullptr && _removable && !_readOnly)
            delete _ndarray;
    }
}

void Variable::setId(int id, int idx) {
    _id = id;
    _index = idx;
}

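// Serializes this Variable (shape, buffer, id/index and optional name) into a FlatBuffers
// FlatVariable. Only NDArray-backed variables can be serialized this way.
//
// Round-trip sketch (illustrative only; the builder setup below is an assumption and not
// part of this file):
//   flatbuffers::FlatBufferBuilder builder(1024);
//   builder.Finish(var->asFlatVariable(builder));
//   auto restored = new Variable(flatbuffers::GetRoot<FlatVariable>(builder.GetBufferPointer()));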
flatbuffers::Offset<FlatVariable> Variable::asFlatVariable(flatbuffers::FlatBufferBuilder &builder) {
    if (this->hasNDArray()) {
        auto array = this->getNDArray();
        auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
        auto fBuffer = builder.CreateVector(array->asByteVector());

        // packing array
        auto fArray = CreateFlatArray(builder, fShape, fBuffer, (sd::graph::DType) array->dataType());

        // packing id/index of this var
        auto fVid = CreateIntPair(builder, this->_id, this->_index);

        // name is still optional
        flatbuffers::Offset<flatbuffers::String> stringId = 0;
        if (!this->_name.empty())
            stringId = builder.CreateString(this->_name);

        // returning array
        return CreateFlatVariable(builder, fVid, stringId, static_cast<sd::graph::DType>(array->dataType()), 0, fArray);
    } else {
        throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");
    }
}
}
}
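
// std::hash specializations declared alongside Variable: pair<int,int> combines both ints
// with the boost-style hash_combine constant, while float16/bfloat16 reuse the float hash.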
namespace std {
size_t hash<std::pair<int, int>>::operator()(const std::pair<int,int>& k) const {
    auto v = std::hash<int>()(k.first);
    v ^= std::hash<int>()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2);
    return v;
}

size_t hash<bfloat16>::operator()(const bfloat16& k) const {
    return std::hash<float>()((float)k);
}

size_t hash<float16>::operator()(const float16& k) const {
    return std::hash<float>()((float)k);
}
}