2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author raver119@gmail.com
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <graph/Node.h>
|
|
|
|
#include <ops/declarable/OpRegistrator.h>
|
|
|
|
#include <ops/declarable/LegacyTransformSameOp.h>
|
|
|
|
#include <ops/declarable/LegacyTransformFloatOp.h>
|
|
|
|
#include <ops/declarable/LegacyScalarOp.h>
|
|
|
|
#include <ops/declarable/LegacyReduceSameOp.h>
|
|
|
|
#include <ops/declarable/LegacyReduceFloatOp.h>
|
|
|
|
#include <ops/declarable/LegacyIndexReduceOp.h>
|
|
|
|
#include <ops/declarable/LegacyStatsOp.h>
|
|
|
|
#include <ops/declarable/LegacyBroadcastOp.h>
|
|
|
|
#include <ops/declarable/LegacyReduce3Op.h>
|
|
|
|
#include <ops/declarable/LegacyPairwiseTransformOp.h>
|
|
|
|
#include <ops/declarable/LegacyRandomOp.h>
|
|
|
|
#include <ops/declarable/LegacyOp.h>
|
|
|
|
#include <ops/declarable/LegacyReduceLongOp.h>
|
|
|
|
#include <ops/declarable/LegacyReduceBoolOp.h>
|
|
|
|
#include <ops/declarable/LegacyBroadcastBoolOp.h>
|
|
|
|
#include <ops/declarable/LegacyScalarBoolOp.h>
|
|
|
|
#include <ops/declarable/LegacyPairwiseTransformBoolOp.h>
|
|
|
|
#include <ops/declarable/LegacyTransformStrictOp.h>
|
|
|
|
#include <ops/declarable/LegacyTransformBoolOp.h>
|
|
|
|
#include <graph/FlatUtils.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <array/NDArrayFactory.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
namespace sd {
|
2019-06-06 14:21:15 +02:00
|
|
|
namespace graph {
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setOuterTime(Nd4jLong time){
|
2019-06-06 14:21:15 +02:00
|
|
|
// if (hasBlockAttached())
|
|
|
|
// _block->setOuterTime(time);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setInnerTime(Nd4jLong time){
|
2019-06-06 14:21:15 +02:00
|
|
|
// if (hasBlockAttached())
|
|
|
|
// _block->setInnerTime(time);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setGraph(sd::graph::Graph* graph) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_graph = graph;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::Graph* sd::graph::Node::getGraph() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _graph;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasGraphEmbedded() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _graph != nullptr;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::markInplace(bool reallyInplace) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_isInplace = reallyInplace;
|
|
|
|
if (_protoContext != nullptr) {
|
|
|
|
_protoContext->markInplace(reallyInplace);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/// Returns the op class assigned during construction (e.g. TRANSFORM or REDUCTION for legacy ops).
OpClass Node::getOpClass() {
    return this->_opClass;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasBlockAttached() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _protoContext != nullptr;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::isInplace() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _isInplace;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::isDivergencePoint() {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (hasCustomOp()) {
|
|
|
|
return _customOp->getOpDescriptor()->isDivergent();
|
|
|
|
} else if (opType() == OpType_LOGIC && opNum() == 30)
|
|
|
|
return true;
|
|
|
|
else
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setActive(bool reallyActive) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_active = reallyActive;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::isActive() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _active;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the frame id (set from extraInteger[0] for Enter nodes built from a FlatNode).
Nd4jLong Node::getFrameId() {
    return this->_frameId;
}
|
|
|
|
|
|
|
|
/// Overwrites the node's frame id.
void Node::setFrameId(Nd4jLong frameId) {
    this->_frameId = frameId;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/// Returns the node's ContextPrototype, creating it lazily on first use.
///
/// A newly created prototype receives the custom op's descriptor when a custom op is set
/// (null otherwise). If the prototype has no inputs yet, the node's inputs are copied into it.
ContextPrototype* Node::getContextPrototype() {
    if (!this->_protoContext) {
        auto op = this->getCustomOp();
        auto descriptor = op != nullptr ? op->getOpDescriptor() : nullptr;
        this->_protoContext = new ContextPrototype(descriptor, this->id());
    }

    auto protoInputs = this->_protoContext->inputs();
    if (protoInputs->empty()) {
        for (const auto &in : *this->input())
            protoInputs->emplace_back(in);
    }

    return this->_protoContext;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setContextPrototype(ContextPrototype *block) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (_protoContext != nullptr)
|
|
|
|
throw std::runtime_error("Block already exists");
|
|
|
|
|
|
|
|
_protoContext = block;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setId(int id) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_id = id;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/// Returns the op attached to this node (custom or legacy wrapper); may be null.
sd::ops::DeclarableOp* Node::getCustomOp() {
    return this->_customOp;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setCustomOp(sd::ops::DeclarableOp *customOp) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_customOp = customOp;
|
|
|
|
|
|
|
|
// divergent ops (Switch etc) are always inplace, they don't allocate anything
|
|
|
|
if (_customOp != nullptr && customOp->getOpDescriptor()->isDivergent())
|
|
|
|
_isInplace = true;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasCustomOp() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _customOp != nullptr;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::string * sd::graph::Node::name() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return this->getName();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::string * sd::graph::Node::getName() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_name;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setName(const std::string& name) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_name = name.c_str();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setName(std::string *name) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_name = *name;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
double sd::graph::Node::scalar() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _scalar.e<double>(0);
|
|
|
|
};
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickInput(std::pair<int,int>& pair) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_input.push_back(pair);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickInput(int inputId, int outputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int,int> p(inputId,outputId);
|
|
|
|
pickInput(p);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickInput(int inputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
pickInput(inputId, 0);
|
|
|
|
|
|
|
|
if (inputId < 0)
|
|
|
|
_hasExternalInputs = true;
|
|
|
|
else
|
|
|
|
_hasInternalInputs = true;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickExternalOutput(int outputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(outputId, 0);
|
|
|
|
_output.push_back(pair);
|
|
|
|
|
|
|
|
_hasExternalOutputs = true;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickOutputOnce(int outputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(outputId, 0);
|
|
|
|
if (std::find(_output.begin(), _output.end(), pair) == _output.end())
|
|
|
|
pickOutput(outputId);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickOutput(int nodeId, int outputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(nodeId, outputId);
|
|
|
|
_output.emplace_back(pair);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::pickOutput(int outputId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(outputId, 0);
|
|
|
|
_output.emplace_back(pair);
|
|
|
|
|
|
|
|
if (outputId < 0)
|
|
|
|
_hasExternalOutputs = true;
|
|
|
|
else
|
|
|
|
_hasInternalOutputs = true;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int * sd::graph::Node::getDimensionsPtr() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _dim;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<int> * sd::graph::Node::getDimensions() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_dimensions;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::Node::getLayer() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _layer;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setLayer(int layer) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_layer = layer;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasExternalOutputs() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasExternalOutputs;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasExternalInputs() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasExternalInputs;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasInternalOutputs() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasInternalOutputs;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::hasInternalInputs() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasInternalInputs;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::isMultiInput() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _input.size() > 1;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::isMultiOutput() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _output.size() > 1;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
double * sd::graph::Node::extraParams() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _extraParams;
|
|
|
|
}
|
|
|
|
|
|
|
|
int Node::totalReferences() {
|
|
|
|
return _referencedBy.size();
|
|
|
|
}
|
|
|
|
|
|
|
|
void Node::addReference(int nodeId) {
|
|
|
|
_referencedBy.emplace_back(nodeId);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::OpType sd::graph::Node::opType() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _opType;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::Node::id() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _id;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/// Returns the op number (the op hash for nodes built around a DeclarableOp).
Nd4jLong Node::opNum() {
    return this->_opNum;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<std::pair<int,int>> *sd::graph::Node::input() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_input;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<std::pair<int, int>> *sd::graph::Node::output() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_output;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool Node::isScoped() {
|
|
|
|
return _scope_id != 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
void Node::setScopeInfo(int id, const char* name) {
|
|
|
|
_scope_id = id;
|
|
|
|
|
|
|
|
if (name != nullptr)
|
|
|
|
_scope_name = name;
|
|
|
|
}
|
|
|
|
|
|
|
|
int Node::scopeId() {
|
|
|
|
return _scope_id;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string* Node::scopeName() {
|
|
|
|
return &_scope_name;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
Node* Node::asT() {
|
|
|
|
auto node = this->clone();
|
|
|
|
node->_dataType = DataTypeUtils::fromT<T>();
|
|
|
|
return node;
|
|
|
|
}
|
2019-12-02 19:37:21 +01:00
|
|
|
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Node* Node::asT, (), LIBND4J_TYPES);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list<int> input, std::initializer_list<int> output, std::initializer_list<int> dimensions, float scalar, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
2019-06-06 14:21:15 +02:00
|
|
|
this->_opType = OpType_CUSTOM;
|
|
|
|
this->_id = id;
|
|
|
|
this->_opNum = customOp->getOpHash();
|
|
|
|
this->_extraParams = nullptr;
|
2020-03-02 10:49:41 +01:00
|
|
|
this->_dataType = sd::DataType::FLOAT32; // float as default
|
2019-06-06 14:21:15 +02:00
|
|
|
this->_dim = nullptr;
|
|
|
|
this->_customOp = customOp;
|
|
|
|
|
|
|
|
_hasExternalInputs = false;
|
|
|
|
_hasExternalOutputs = false;
|
|
|
|
_hasInternalInputs = false;
|
|
|
|
_hasInternalOutputs = false;
|
|
|
|
|
|
|
|
_scalar = NDArrayFactory::create(scalar);
|
|
|
|
|
|
|
|
for (auto i: input)
|
|
|
|
pickInput(i);
|
|
|
|
|
|
|
|
for (auto o: output)
|
|
|
|
pickOutput(o);
|
|
|
|
|
|
|
|
if (dimensions.size() > 0) {
|
|
|
|
_dim = new int[dimensions.size()];
|
|
|
|
int cnt = 0;
|
|
|
|
for (auto d: dimensions) {
|
|
|
|
_dimensions.push_back(d);
|
|
|
|
_dim[cnt++] = d;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false);
|
|
|
|
|
|
|
|
for (auto v: dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: iArgs)
|
|
|
|
block->getIArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: tArgs)
|
|
|
|
block->getTArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
this->setContextPrototype(block);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setOpType(OpType opType) {
|
2019-06-06 14:21:15 +02:00
|
|
|
this->_opType = opType;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list<int> input, std::initializer_list<int> output, std::initializer_list<int> dimensions, float scalar, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
2019-06-06 14:21:15 +02:00
|
|
|
this->_opType = opType;
|
|
|
|
this->_id = id;
|
|
|
|
this->_opNum = opNum;
|
|
|
|
this->_extraParams = nullptr;
|
2020-03-02 10:49:41 +01:00
|
|
|
this->_dataType = sd::DataType::FLOAT32; // float as default
|
2019-06-06 14:21:15 +02:00
|
|
|
this->_dim = nullptr;
|
|
|
|
|
|
|
|
_hasExternalInputs = false;
|
|
|
|
_hasExternalOutputs = false;
|
|
|
|
_hasInternalInputs = false;
|
|
|
|
_hasInternalOutputs = false;
|
|
|
|
|
|
|
|
_scalar = NDArrayFactory::create(scalar);
|
|
|
|
|
|
|
|
for (auto i: input)
|
|
|
|
pickInput(i);
|
|
|
|
|
|
|
|
for (auto o: output)
|
|
|
|
pickOutput(o);
|
|
|
|
|
|
|
|
if (dimensions.size() > 0) {
|
|
|
|
_dim = new int[dimensions.size()];
|
|
|
|
int cnt = 0;
|
|
|
|
for (auto d: dimensions) {
|
|
|
|
_dimensions.push_back(d);
|
|
|
|
_dim[cnt++] = d;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// these ops allow in-place execution by design
|
|
|
|
if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || opType == OpType_SCALAR || opType == OpType_BROADCAST) {
|
|
|
|
if (_output.size() <= 1) {
|
|
|
|
_isInplace = true;
|
|
|
|
}
|
|
|
|
_opClass = OpClass_TRANSFORM;
|
|
|
|
} else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || opType == OpType_SUMMARYSTATS) {
|
|
|
|
_opClass = OpClass_REDUCTION;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (opType == OpType_BROADCAST ||
|
|
|
|
opType == OpType_BROADCAST_BOOL ||
|
|
|
|
opType == OpType_INDEX_REDUCE ||
|
|
|
|
opType == OpType_SUMMARYSTATS ||
|
|
|
|
opType == OpType_REDUCE_BOOL ||
|
|
|
|
opType == OpType_REDUCE_SAME ||
|
|
|
|
opType == OpType_REDUCE_FLOAT ||
|
|
|
|
opType == OpType_REDUCE_3 ||
|
|
|
|
opType == OpType_TRANSFORM_STRICT ||
|
|
|
|
opType == OpType_TRANSFORM_SAME ||
|
|
|
|
opType == OpType_TRANSFORM_FLOAT ||
|
|
|
|
opType == OpType_TRANSFORM_BOOL ||
|
|
|
|
opType == OpType_RANDOM ||
|
|
|
|
opType == OpType_PAIRWISE ||
|
|
|
|
opType == OpType_PAIRWISE_BOOL ||
|
|
|
|
opType == OpType_SCALAR_BOOL ||
|
|
|
|
opType == OpType_SCALAR) {
|
|
|
|
|
|
|
|
this->_isDeductable = true;
|
|
|
|
|
|
|
|
auto block = new ContextPrototype(nullptr, this->id(), false);
|
|
|
|
|
|
|
|
for (auto v: dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: iArgs)
|
|
|
|
block->getIArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: tArgs)
|
|
|
|
block->getTArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
this->setContextPrototype(block);
|
|
|
|
this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), opNum, &_scalar));
|
|
|
|
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
|
|
|
|
} else if (opType == OpType_CUSTOM) {
|
|
|
|
if (this->getCustomOp()) {
|
|
|
|
auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false);
|
|
|
|
|
|
|
|
for (auto v: dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: iArgs)
|
|
|
|
block->getIArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: tArgs)
|
|
|
|
block->getTArguments()->emplace_back(v);
|
|
|
|
|
|
|
|
this->setContextPrototype(block);
|
|
|
|
} else throw std::runtime_error("wrong custom operation given");
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::Node::Node(const sd::graph::FlatNode *node) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_hasExternalInputs = false;
|
|
|
|
_hasExternalOutputs = false;
|
|
|
|
_hasInternalInputs = false;
|
|
|
|
_hasInternalOutputs = false;
|
|
|
|
_extraParams = nullptr;
|
|
|
|
_dim = nullptr;
|
2020-03-02 10:49:41 +01:00
|
|
|
_dataType = sd::DataType::FLOAT32; // float as default
|
2019-06-06 14:21:15 +02:00
|
|
|
if (node->scope_id() != 0)
|
|
|
|
this->_scope_id = node->scope_id();
|
|
|
|
|
|
|
|
if (node->scope_name() != nullptr && node->scope_name()->size() > 0)
|
|
|
|
this->_scope_name = node->scope_name()->str();
|
|
|
|
|
|
|
|
if (node->scalar() != nullptr) {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto scalar = sd::graph::FlatUtils::fromFlatArray(node->scalar());
|
2019-06-06 14:21:15 +02:00
|
|
|
_scalar = *scalar;
|
|
|
|
delete scalar;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node != nullptr) {
|
|
|
|
this->_id = node->id();
|
|
|
|
//this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType());
|
|
|
|
this->_opNum = node->opNum();
|
|
|
|
this->_opType = node->opType();
|
|
|
|
|
|
|
|
if (node->name() != nullptr && node->name()->c_str() != nullptr) {
|
|
|
|
this->_name = node->name()->str();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) {
|
|
|
|
for (int e = 0; e < (int) node->inputPaired()->size(); e++) {
|
|
|
|
auto pair = node->inputPaired()->Get(e);
|
|
|
|
pickInput(pair->first(), pair->second());
|
|
|
|
}
|
|
|
|
} else if (node->input() != nullptr && node->input()->size() > 0) {
|
|
|
|
for (int e = 0; e < (int) node->input()->size(); e++)
|
|
|
|
pickInput(node->input()->Get(e));
|
|
|
|
} else {
|
|
|
|
if (this->opType() != OpType_LOGIC) {
|
|
|
|
if (this->_name.size() > 0) {
|
|
|
|
nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str());
|
|
|
|
} else {
|
|
|
|
nd4j_debug("Node [%i:<noname>] has no inputs defined\n", this->_id);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
if (node->output() != nullptr)
|
|
|
|
for (int e = 0; e < (int) node->output()->size(); e++) {
|
|
|
|
auto oid = node->output()->Get(e);
|
|
|
|
if (oid != this->_id && oid != 0) {
|
|
|
|
nd4j_verbose("Picking output: %i\n", node->output()->Get(e));
|
|
|
|
pickOutput(oid);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
if (node->extraParams() != nullptr && node->extraParams()->size() > 0) {
|
|
|
|
_extraParams = new double[node->extraParams()->size()];
|
|
|
|
for (int e = 0; e < (int) node->extraParams()->size(); e++) {
|
|
|
|
_extraParams[e] = static_cast<double>(node->extraParams()->Get(e));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->dimensions() != nullptr && node->dimensions()->size() > 0) {
|
|
|
|
_dim = new int[node->dimensions()->size()];
|
|
|
|
for (int e = 0; e < (int) node->dimensions()->size(); e++) {
|
|
|
|
_dimensions.emplace_back(node->dimensions()->Get(e));
|
|
|
|
_dim[e] = node->dimensions()->Get(e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (this->opType() == OpType_LOGIC && this->opNum() == 100L) {
|
|
|
|
if (node->extraInteger()->size() < 1) {
|
|
|
|
nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", this->id());
|
|
|
|
throw std::runtime_error("Enter node must have FrameID specified");
|
|
|
|
}
|
|
|
|
|
|
|
|
this->setFrameId(node->extraInteger()->Get(0));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// these ops allow in-place execution by design
|
|
|
|
if (_opType == OpType_BROADCAST ||
|
|
|
|
_opType == OpType_BROADCAST_BOOL ||
|
|
|
|
_opType == OpType_INDEX_REDUCE ||
|
|
|
|
_opType == OpType_SUMMARYSTATS ||
|
|
|
|
_opType == OpType_REDUCE_BOOL ||
|
|
|
|
_opType == OpType_REDUCE_SAME ||
|
|
|
|
_opType == OpType_REDUCE_FLOAT ||
|
|
|
|
_opType == OpType_REDUCE_3 ||
|
|
|
|
_opType == OpType_TRANSFORM_STRICT ||
|
|
|
|
_opType == OpType_TRANSFORM_SAME ||
|
|
|
|
_opType == OpType_TRANSFORM_FLOAT ||
|
|
|
|
_opType == OpType_TRANSFORM_BOOL ||
|
|
|
|
_opType == OpType_RANDOM ||
|
|
|
|
_opType == OpType_PAIRWISE ||
|
|
|
|
_opType == OpType_PAIRWISE_BOOL ||
|
|
|
|
_opType == OpType_SCALAR_BOOL ||
|
|
|
|
_opType == OpType_SCALAR) {
|
|
|
|
|
|
|
|
if (_output.size() <= 1) {
|
|
|
|
_isInplace = true;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->input() != nullptr && node->input()->size() > 0) {
|
|
|
|
this->_isDeductable = true;
|
|
|
|
|
|
|
|
auto block = new ContextPrototype(nullptr, this->id(), false);
|
|
|
|
|
|
|
|
|
|
|
|
for (auto v: _dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
if (node->extraParams() != nullptr && node->extraParams()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraParams()->size(); e++) {
|
|
|
|
block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraBools()->size(); e++) {
|
|
|
|
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraInteger()->size(); e++) {
|
|
|
|
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
|
|
|
}
|
|
|
|
|
2020-01-30 16:46:12 +01:00
|
|
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
|
|
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e));
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
this->setContextPrototype(block);
|
|
|
|
this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
|
|
|
|
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
|
|
|
|
} else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) {
|
|
|
|
this->_isDeductable = true;
|
|
|
|
|
|
|
|
auto block = new ContextPrototype(nullptr, this->id(), false);
|
|
|
|
|
|
|
|
for (int e = 0; e < this->input()->size(); e++) {
|
|
|
|
block->inputs()->emplace_back(this->input()->at(e));
|
|
|
|
}
|
|
|
|
|
|
|
|
// there's no other IArgs in legacy options, actually
|
|
|
|
for (auto v: _dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
if (node->extraParams() != nullptr && node->extraParams()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraParams()->size(); e++) {
|
|
|
|
block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraBools()->size(); e++) {
|
|
|
|
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraInteger()->size(); e++) {
|
|
|
|
block->getIArguments()->emplace_back(node->extraInteger()->Get(e));
|
|
|
|
}
|
|
|
|
|
2020-01-30 16:46:12 +01:00
|
|
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
|
|
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e));
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
this->setContextPrototype(block);
|
|
|
|
|
|
|
|
this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar));
|
|
|
|
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
|
|
|
|
}
|
|
|
|
} else if (this->_opType == OpType_CUSTOM) {
|
2020-06-06 14:26:55 +02:00
|
|
|
auto op = sd::ops::OpRegistrator::getInstance().getOperation(this->opNum());
|
2019-06-06 14:21:15 +02:00
|
|
|
if (op == nullptr) {
|
|
|
|
nd4j_verbose("Can't find operation: %lld\n", this->opNum());
|
|
|
|
throw std::runtime_error("Can't find requested operation");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto block = new ContextPrototype(nullptr, this->id());
|
|
|
|
|
|
|
|
for (int e = 0; e < this->input()->size(); e++) {
|
|
|
|
block->inputs()->emplace_back(this->input()->at(e));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraInteger() != nullptr)
|
|
|
|
for (uint32_t e = 0; e < node->extraInteger()->size(); e++) {
|
|
|
|
auto v = node->extraInteger()->Get(e);
|
|
|
|
// FIXME: remove this static_cast, iArgs should be Nd4jLong
|
|
|
|
block->getIArguments()->emplace_back(static_cast<int>(v));
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node->extraParams() != nullptr)
|
|
|
|
for (uint32_t e = 0; e < node->extraParams()->size(); e++)
|
|
|
|
block->getTArguments()->emplace_back(static_cast<double>(node->extraParams()->Get(e)));
|
|
|
|
|
|
|
|
if (node->extraBools() != nullptr && node->extraBools()->size() > 0)
|
|
|
|
for (int e = 0; e < (int) node->extraBools()->size(); e++) {
|
|
|
|
block->getBArguments()->push_back(node->extraBools()->Get(e));
|
|
|
|
}
|
|
|
|
|
2020-01-30 16:46:12 +01:00
|
|
|
if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) {
|
|
|
|
for (int e = 0; e < (int) node->extraTypes()->size(); e++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e));
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
for (auto v: _dimensions)
|
|
|
|
block->getAxis()->emplace_back(v);
|
|
|
|
|
|
|
|
this->setContextPrototype(block);
|
|
|
|
this->setCustomOp(op);
|
|
|
|
block->setOpDescriptor(this->getCustomOp()->getOpDescriptor());
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// empty dynamic node, tests probably
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/// Returns the node's data type (FLOAT32 unless changed, e.g. by asT<T>()).
sd::DataType Node::dataType() {
    return this->_dataType;
}
|
|
|
|
|
|
|
|
/// Returns the attached ContextPrototype without creating one (may be null);
/// see getContextPrototype() for the lazily-creating variant.
ContextPrototype* Node::protoContext() {
    return this->_protoContext;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::graph::Node::~Node() {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (_extraParams != nullptr)
|
|
|
|
delete[] _extraParams;
|
|
|
|
|
|
|
|
if (_dim != nullptr)
|
|
|
|
delete[] _dim;
|
|
|
|
|
|
|
|
if (_protoContext != nullptr)
|
|
|
|
delete _protoContext;
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
if (_isDeductable && _customOp != nullptr) {
|
|
|
|
Node::deleteOpByType(_opType, _customOp);
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::Node::getRewindNode() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _rewindNode;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setRewindNode(int nodeId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_rewindNode = nodeId;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::pair<int, int>& sd::graph::Node::getRewindLayer() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _rewindLayer;
|
|
|
|
};
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::Node::setRewindLayer(int layerId, int stepId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_rewindLayer.first = layerId;
|
|
|
|
_rewindLayer.second = stepId;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::Node::equals(Node *other) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (_opType == other->_opType && _dataType == other->_dataType && _opNum == other->_opNum)
|
|
|
|
return true;
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Destroys an op passed as an untyped pointer by first casting it back to the concrete
// op class that corresponds to `opType` (the same mapping buildOpByType uses, plus
// OpType_CUSTOM -> DeclarableOp).
//
// @param opType  selects the concrete class to delete `op` as
// @param op      op instance to delete
// @throws std::runtime_error for any opType not listed below
//
// void* -> T* is a static_cast conversion; reinterpret_cast is not needed here.
void sd::graph::Node::deleteOpByType(OpType opType, void *op) {
    switch (opType) {
        case OpType_PAIRWISE:
            delete static_cast<sd::ops::LegacyPairwiseTransformOp*>(op);
            break;
        case OpType_PAIRWISE_BOOL:
            delete static_cast<sd::ops::LegacyPairwiseTransformBoolOp*>(op);
            break;
        case OpType_TRANSFORM_STRICT:
            delete static_cast<sd::ops::LegacyTransformStrictOp*>(op);
            break;
        case OpType_TRANSFORM_SAME:
            delete static_cast<sd::ops::LegacyTransformSameOp*>(op);
            break;
        case OpType_TRANSFORM_FLOAT:
            delete static_cast<sd::ops::LegacyTransformFloatOp*>(op);
            break;
        case OpType_TRANSFORM_BOOL:
            delete static_cast<sd::ops::LegacyTransformBoolOp*>(op);
            break;
        case OpType_SCALAR:
            delete static_cast<sd::ops::LegacyScalarOp*>(op);
            break;
        case OpType_SCALAR_BOOL:
            delete static_cast<sd::ops::LegacyScalarBoolOp*>(op);
            break;
        case OpType_REDUCE_3:
            delete static_cast<sd::ops::LegacyReduce3Op*>(op);
            break;
        case OpType_REDUCE_SAME:
            delete static_cast<sd::ops::LegacyReduceSameOp*>(op);
            break;
        case OpType_REDUCE_FLOAT:
            delete static_cast<sd::ops::LegacyReduceFloatOp*>(op);
            break;
        case OpType_REDUCE_LONG:
            delete static_cast<sd::ops::LegacyReduceLongOp*>(op);
            break;
        case OpType_REDUCE_BOOL:
            delete static_cast<sd::ops::LegacyReduceBoolOp*>(op);
            break;
        case OpType_INDEX_REDUCE:
            delete static_cast<sd::ops::LegacyIndexReduceOp*>(op);
            break;
        case OpType_SUMMARYSTATS:
            delete static_cast<sd::ops::LegacyStatsOp*>(op);
            break;
        case OpType_RANDOM:
            delete static_cast<sd::ops::LegacyRandomOp*>(op);
            break;
        case OpType_BROADCAST:
            delete static_cast<sd::ops::LegacyBroadcastOp*>(op);
            break;
        case OpType_BROADCAST_BOOL:
            delete static_cast<sd::ops::LegacyBroadcastBoolOp*>(op);
            break;
        case OpType_CUSTOM:
            delete static_cast<sd::ops::DeclarableOp*>(op);
            break;
        default:
            throw std::runtime_error("Bad opType passed in");
    }
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Factory: heap-allocates the legacy op class that corresponds to `opType`, constructed
// with `opNum`. The returned pointer is not stored anywhere here, so the caller owns it
// (a deductable Node releases it through deleteOpByType).
//
// @param opType     legacy op family to instantiate (OpType_CUSTOM is NOT handled here)
// @param numInputs  unused by this implementation
// @param numIArgs   unused by this implementation
// @param numTArgs   unused by this implementation
// @param opNum      op number forwarded to the op constructor
// @param scalar     optional; for SCALAR / SCALAR_BOOL, when non-null *scalar is passed
//                   to the op constructor, otherwise the opNum-only constructor is used
// @throws std::runtime_error for any opType not listed below
sd::ops::DeclarableOp* sd::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) {
    namespace ops = sd::ops;

    switch (opType) {
        case OpType_PAIRWISE:         return new ops::LegacyPairwiseTransformOp(opNum);
        case OpType_PAIRWISE_BOOL:    return new ops::LegacyPairwiseTransformBoolOp(opNum);
        case OpType_TRANSFORM_STRICT: return new ops::LegacyTransformStrictOp(opNum);
        case OpType_TRANSFORM_SAME:   return new ops::LegacyTransformSameOp(opNum);
        case OpType_TRANSFORM_FLOAT:  return new ops::LegacyTransformFloatOp(opNum);
        case OpType_TRANSFORM_BOOL:   return new ops::LegacyTransformBoolOp(opNum);

        case OpType_SCALAR:
            if (scalar == nullptr)
                return new ops::LegacyScalarOp(opNum);
            return new ops::LegacyScalarOp(opNum, *scalar);

        case OpType_SCALAR_BOOL:
            if (scalar == nullptr)
                return new ops::LegacyScalarBoolOp(opNum);
            return new ops::LegacyScalarBoolOp(opNum, *scalar);

        case OpType_REDUCE_3:         return new ops::LegacyReduce3Op(opNum);
        case OpType_REDUCE_SAME:      return new ops::LegacyReduceSameOp(opNum);
        case OpType_REDUCE_FLOAT:     return new ops::LegacyReduceFloatOp(opNum);
        case OpType_REDUCE_LONG:      return new ops::LegacyReduceLongOp(opNum);
        case OpType_REDUCE_BOOL:      return new ops::LegacyReduceBoolOp(opNum);
        case OpType_INDEX_REDUCE:     return new ops::LegacyIndexReduceOp(opNum);
        case OpType_SUMMARYSTATS:     return new ops::LegacyStatsOp(opNum);
        case OpType_RANDOM:           return new ops::LegacyRandomOp(opNum);
        case OpType_BROADCAST:        return new ops::LegacyBroadcastOp(opNum);
        case OpType_BROADCAST_BOOL:   return new ops::LegacyBroadcastBoolOp(opNum);

        default:
            throw std::runtime_error("Bad opType passed in");
    }
}
|
|
|
|
|
|
|
|
bool Node::isDeductable() {
|
|
|
|
return _isDeductable;
|
|
|
|
}
|
|
|
|
|
|
|
|
void Node::setDeductable(bool reallyDeductable) {
|
|
|
|
_isDeductable = reallyDeductable;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Creates a heap-allocated copy of this node; the caller owns the returned Node.
//
// * OpType_CUSTOM with a non-null op: the copy is built with the Node(op, id) constructor
//   from this node's op pointer, then pullValues() copies the node state.
// * Everything else: a fresh Node(_opType, _opNum, _id) receives this node's state via
//   pullValues(). If this node is deductable (owns its op), the op is deep-copied through
//   LegacyOp::clone(); otherwise the op pointer is shared with the copy.
//
// @throws std::runtime_error if a deductable, non-custom node's op is null or is not a
//         LegacyOp (previously the failed dynamic_cast result was dereferenced unchecked).
Node* Node::clone() {
    if (this->_customOp && this->_opType == sd::graph::OpType_CUSTOM) {
        auto clone = new Node(this->_customOp, _id);
        clone->pullValues(this);
        return clone;
    }

    // Resolve and validate the legacy op before allocating, so a failure cannot leak the copy.
    sd::ops::LegacyOp* legacyOp = nullptr;
    if (_isDeductable) {
        legacyOp = dynamic_cast<sd::ops::LegacyOp*>(_customOp);
        if (legacyOp == nullptr)
            throw std::runtime_error("Node::clone: deductable node does not hold a LegacyOp");
    }

    auto clone = new Node(_opType, _opNum, _id);
    clone->pullValues(this);

    // op time
    if (_isDeductable)
        clone->_customOp = legacyOp->clone();
    else
        clone->_customOp = _customOp;

    return clone;
}
|
|
|
|
}
|
|
|
|
}
|