340 lines
12 KiB
C++
340 lines
12 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author raver119@gmail.com
|
||
|
//
|
||
|
|
||
|
#include <helpers/EnumUtils.h>
|
||
|
#include <graph/Variable.h>
|
||
|
#include <array/DataTypeUtils.h>
|
||
|
#include <array/ByteOrderUtils.h>
|
||
|
#include <array/DataTypeConversions.h>
|
||
|
#include <graph/FlatUtils.h>
|
||
|
#include <helpers/StringUtils.h>
|
||
|
|
||
|
namespace nd4j {
|
||
|
namespace graph {
|
||
|
|
||
|
template <typename N>
|
||
|
Variable* Variable::asT() {
|
||
|
auto result = new Variable(this->isPlaceholder());
|
||
|
|
||
|
result->markExternal(this->_external);
|
||
|
result->setId(this->_id);
|
||
|
result->markReadOnly(this->_readOnly);
|
||
|
result->setName(&this->_name);
|
||
|
result->setIndex(this->_index);
|
||
|
|
||
|
if (this->_ndarray != nullptr)
|
||
|
result->setNDArray(this->_ndarray->template asT<N>());
|
||
|
|
||
|
// FIXME: add support for ArrayList
|
||
|
if (this->_list != nullptr) {
|
||
|
nd4j_printf("ArrayList not supported yet\n", "");
|
||
|
throw std::runtime_error("ArrayList not supported yet for asT");
|
||
|
}
|
||
|
|
||
|
return result;
|
||
|
}
|
||
|
BUILD_SINGLE_TEMPLATE(template Variable* Variable::asT, (), LIBND4J_TYPES);
|
||
|
|
||
|
/**
 * Deep-copies this variable into a new heap-allocated Variable owned by the caller.
 *
 * Copies the placeholder flag, external/read-only flags, id, name, output
 * index, variable type and placeholder shape. The NDArray (if any) is
 * duplicated with dup() using the source array's ordering, and the
 * NDArrayList (if any) is copied via NDArrayList::clone().
 *
 * @return new Variable independent of this one's array/list storage
 */
nd4j::graph::Variable* nd4j::graph::Variable::clone() {
    auto result = new Variable(this->isPlaceholder());

    result->_external = this->_external;
    result->_id = this->_id;
    result->_readOnly = this->_readOnly;
    result->_name = this->_name;
    result->_index = this->_index;

    // These two were previously not copied: a cloned ARRAY_LIST or
    // PLACEHOLDER variable ended up reporting whatever type the constructor
    // defaults to, and lost the placeholder shape read from the FlatVariable.
    result->_variableType = this->_variableType;
    result->_shape = this->_shape;

    if (this->_ndarray != nullptr)
        result->_ndarray = this->_ndarray->dup(this->_ndarray->ordering());

    if (this->_list != nullptr)
        result->_list = this->_list->clone();

    return result;
}
|
||
|
|
||
|
void nd4j::graph::Variable::setIndex(int index) {
|
||
|
_index = index;
|
||
|
}
|
||
|
|
||
|
bool nd4j::graph::Variable::hasNDArray() {
|
||
|
return _ndarray != nullptr;
|
||
|
}
|
||
|
|
||
|
/** Overrides the stored variable type without touching the attached array/list. */
void nd4j::graph::Variable::setVariableType(VariableType variableType) {
    this->_variableType = variableType;
}
|
||
|
|
||
|
bool nd4j::graph::Variable::hasNDArrayList() {
|
||
|
return _list != nullptr;
|
||
|
}
|
||
|
|
||
|
bool nd4j::graph::Variable::isPlaceholder() {
|
||
|
return _placeholder;
|
||
|
}
|
||
|
|
||
|
std::string * nd4j::graph::Variable::getName() {
|
||
|
return &_name;
|
||
|
}
|
||
|
|
||
|
void nd4j::graph::Variable::setName(std::string *name) {
|
||
|
_name = *name;
|
||
|
}
|
||
|
|
||
|
int nd4j::graph::Variable::id() {
|
||
|
return _id;
|
||
|
}
|
||
|
|
||
|
int nd4j::graph::Variable::index() {
|
||
|
return _index;
|
||
|
}
|
||
|
|
||
|
void nd4j::graph::Variable::setId(int id) {
|
||
|
_id = id;
|
||
|
}
|
||
|
|
||
|
bool nd4j::graph::Variable::isEmpty() {
|
||
|
if (_variableType == VariableType::NDARRAY)
|
||
|
return _ndarray == nullptr || !_ndarray->nonNull();
|
||
|
else if (_variableType == VariableType::ARRAY_LIST)
|
||
|
return _list == nullptr;
|
||
|
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
bool nd4j::graph::Variable::isExternal() {
|
||
|
return _external;
|
||
|
}
|
||
|
|
||
|
bool nd4j::graph::Variable::isReadOnly() {
|
||
|
return _readOnly;
|
||
|
}
|
||
|
|
||
|
void nd4j::graph::Variable::markExternal(bool reallyExternal) {
|
||
|
this->_external = reallyExternal;
|
||
|
}
|
||
|
|
||
|
/**
 * Sets the removable flag. The destructor only deletes the attached NDArray
 * when the variable is removable and not read-only.
 */
void nd4j::graph::Variable::markRemovable(bool reallyRemovable) {
    // A leftover `nd4j_debug("","")` used to sit here, unbraced under an `if`.
    // It logged an empty message and has been removed.
    this->_removable = reallyRemovable;
}
|
||
|
|
||
|
void nd4j::graph::Variable::markReadOnly(bool reallyReadOnly) {
|
||
|
this->_readOnly = reallyReadOnly;
|
||
|
}
|
||
|
|
||
|
/**
 * Returns the attached NDArray.
 *
 * A type mismatch (variable type other than NDARRAY) is only logged; the
 * lookup still proceeds.
 *
 * @throws std::runtime_error if no NDArray is attached. The message identifies
 *         the variable by name when it has one, otherwise by node id, plus its
 *         output index.
 */
nd4j::NDArray * nd4j::graph::Variable::getNDArray() {
    if (_variableType != VariableType::NDARRAY) {
        nd4j_printf("Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType));
    }

    if (this->_ndarray != nullptr)
        return this->_ndarray;

    // Prefer the human-readable name; fall back to the numeric node id.
    std::string owner = this->_name;
    if (owner.empty())
        owner = StringUtils::valueToString<int>(this->id());

    auto outputIndex = StringUtils::valueToString<int>(this->index());
    throw std::runtime_error("Array doesn't exist for Variable <" + owner + ":" + outputIndex + ">");
}
|
||
|
|
||
|
/**
 * Returns the attached NDArrayList pointer, which may be null.
 * A type other than ARRAY_LIST is only reported via a debug log message.
 */
nd4j::NDArrayList * nd4j::graph::Variable::getNDArrayList() {
    const bool typeMismatch = _variableType != VariableType::ARRAY_LIST;
    if (typeMismatch) {
        nd4j_debug("Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType));
    }

    return _list;
}
|
||
|
|
||
|
|
||
|
bool Variable::isRemovable() {
|
||
|
return _removable;
|
||
|
}
|
||
|
|
||
|
|
||
|
/** Attaches an NDArrayList and switches the variable type to ARRAY_LIST. */
void nd4j::graph::Variable::setNDArrayList(nd4j::NDArrayList * list) {
    _list = list;
    _variableType = VariableType::ARRAY_LIST;
}
|
||
|
|
||
|
|
||
|
/**
 * Attaches an NDArray and switches the variable type to NDARRAY.
 * Any previously attached array pointer is overwritten, not freed.
 */
void nd4j::graph::Variable::setNDArray(nd4j::NDArray * array) {
    _ndarray = array;
    _variableType = VariableType::NDARRAY;
}
|
||
|
|
||
|
|
||
|
/** @return the current variable type (NDARRAY, ARRAY_LIST, PLACEHOLDER, ...). */
VariableType nd4j::graph::Variable::variableType() {
    return this->_variableType;
}
|
||
|
|
||
|
|
||
|
/**
 * Deserializes a Variable from its FlatBuffers representation.
 *
 * Reads the id:index pair and optional name, marks the variable external and
 * writable, then decodes the payload according to the serialized VarType:
 *  - VARIABLE / ARRAY: decode the bundled array if one is present; type NDARRAY.
 *  - CONSTANT: a bundled array is mandatory; type NDARRAY.
 *  - PLACEHOLDER: needs a bundled array and/or a shape. The shape (if any) is
 *    stored in _shape. The type is NDARRAY when an array was decoded, and
 *    PLACEHOLDER when only a shape was given.
 *
 * @throws std::runtime_error for a CONSTANT without an array, a PLACEHOLDER
 *         with neither array nor shape, or an unknown VarType.
 */
nd4j::graph::Variable::Variable(const nd4j::graph::FlatVariable *flatVariable) {
    auto vid = flatVariable->id();
    this->_id = vid->first();
    this->_index = vid->second();

    if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0)
        this->_name = flatVariable->name()->str();

    _external = true;
    _readOnly = false;

    switch (flatVariable->variabletype()) {
        // VARIABLE and ARRAY were previously handled by two identical case bodies.
        case VarType_VARIABLE:
        case VarType_ARRAY: {
            // The array is optional for these kinds; presumably it is filled in
            // later during execution when absent (TODO confirm).
            if (flatVariable->ndarray() != nullptr)
                _ndarray = nd4j::graph::FlatUtils::fromFlatArray(flatVariable->ndarray());

            _variableType = VariableType::NDARRAY;
        }
            break;
        case VarType_CONSTANT: {
            if (flatVariable->ndarray() == nullptr)
                throw std::runtime_error("CONSTANT variable must have NDArray bundled");

            // UTF8 and numeric arrays go through the same decoder; the former
            // dtype-specific if/else had two identical branches.
            _ndarray = nd4j::graph::FlatUtils::fromFlatArray(flatVariable->ndarray());
            _variableType = VariableType::NDARRAY;
        }
            break;
        case VarType_PLACEHOLDER: {
            if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr)
                throw std::runtime_error("PLACEHOLDER variable must have shape defined");

            if (flatVariable->ndarray() != nullptr) {
                _ndarray = nd4j::graph::FlatUtils::fromFlatArray(flatVariable->ndarray());
                _variableType = VariableType::NDARRAY;
            }

            auto flatShape = flatVariable->shape();
            if (flatShape != nullptr) {
                // Unsigned index matches flatbuffers' size() type (was a signed int).
                for (flatbuffers::uoffset_t i = 0; i < flatShape->size(); i++)
                    _shape.emplace_back(flatShape->Get(i));

                // Only a shape-only placeholder (no decoded array) keeps PLACEHOLDER type.
                if (_ndarray == nullptr)
                    _variableType = VariableType::PLACEHOLDER;
            }
        }
            break;
        default:
            throw std::runtime_error("Unknown variable type used");
    }
}
|
||
|
|
||
|
std::vector<Nd4jLong>& nd4j::graph::Variable::shape() {
|
||
|
return _shape;
|
||
|
}
|
||
|
|
||
|
/** Creates an empty variable, recording only whether it is a placeholder. */
nd4j::graph::Variable::Variable(bool placeholder) : _placeholder(placeholder) {
}
|
||
|
|
||
|
|
||
|
/**
 * Wraps an existing NDArray pointer in a non-external, writable variable.
 *
 * @param array NDArray to attach; when non-null the type becomes NDARRAY.
 * @param name  optional C string copied into the variable's name; ignored if null.
 */
nd4j::graph::Variable::Variable(NDArray *array, const char *name ) {
    _external = false;
    _readOnly = false;
    _ndarray = array;

    if (name != nullptr)
        _name = name;

    if (array != nullptr)
        _variableType = VariableType::NDARRAY;
}
|
||
|
|
||
|
|
||
|
/**
 * Same as Variable(array, name), and additionally assigns the node id and
 * output index.
 */
nd4j::graph::Variable::Variable(NDArray *array, const char *name, int id, int idx) : Variable(array, name) {
    this->setId(id, idx);
}
|
||
|
|
||
|
|
||
|
/**
 * Frees the attached NDArray, but only for NDARRAY-typed variables whose array
 * is both removable and not read-only. An attached NDArrayList is not deleted here.
 */
nd4j::graph::Variable::~Variable() {
    if (_variableType != VariableType::NDARRAY)
        return;

    nd4j_debug("Removing variable <%i:%i>\n", _id, _index);

    const bool ownsArray = _removable && !_readOnly;
    if (ownsArray && _ndarray != nullptr) {
        delete _ndarray;
    }
}
|
||
|
|
||
|
|
||
|
void Variable::setId(int id, int idx) {
|
||
|
_id = id;
|
||
|
_index = idx;
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
 * Serializes this variable (id:index, optional name, data type and NDArray)
 * into the given FlatBufferBuilder.
 *
 * @param builder builder that receives the FlatArray, IntPair, optional name
 *                string and the FlatVariable table
 * @return offset of the created FlatVariable
 * @throws std::runtime_error when no NDArray is attached
 */
flatbuffers::Offset<FlatVariable> Variable::asFlatVariable(flatbuffers::FlatBufferBuilder &builder) {
    if (!this->hasNDArray())
        throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList");

    auto array = this->getNDArray();
    auto flatDtype = static_cast<nd4j::graph::DataType>(array->dataType());

    // Builder calls keep their original order: nested vectors/tables are
    // created before the FlatVariable that references them.
    auto shapeVector = builder.CreateVector(array->getShapeInfoAsFlatVector());
    auto bufferVector = builder.CreateVector(array->asByteVector());
    auto flatArray = CreateFlatArray(builder, shapeVector, bufferVector, flatDtype);

    auto idPair = CreateIntPair(builder, this->_id, this->_index);

    // The name is optional; a zero offset means "absent".
    flatbuffers::Offset<flatbuffers::String> nameOffset = 0;
    if (!this->_name.empty())
        nameOffset = builder.CreateString(this->_name);

    return CreateFlatVariable(builder, idPair, nameOffset, flatDtype, 0, flatArray);
}
|
||
|
}
|
||
|
}
|