2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// Created by raver119 on 13.10.2017.
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <ops/declarable/OpDescriptor.h>
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
namespace sd {
|
2019-06-06 14:21:15 +02:00
|
|
|
namespace ops {
|
|
|
|
|
|
|
|
// Builds a descriptor from an op name and a "logic" flag.
// Only _opName and _logic are assigned here; all remaining members keep
// whatever initial state the class declaration provides.
OpDescriptor::OpDescriptor(const char * opName, bool isLogic) {
    this->_opName = opName;
    this->_logic = isLogic;
}
|
|
|
|
|
|
|
|
// Builds a single-output descriptor with the given number of inputs.
// The op class is set to OpClass_CONDITIONAL and the hash is derived from
// the stored op name.
//   numInputs - number of inputs the op takes
//   opName    - op name (C string)
//   isScalar  - value stored in the _scalar flag
OpDescriptor::OpDescriptor(int numInputs, const char * opName, bool isScalar) {
    // the name has to be stored first: the hash is computed from _opName
    _opName = opName;
    _hash = sd::ops::HashHelper::getInstance().getLongHash(_opName);

    _numInputs = numInputs;
    _numOutputs = 1;

    _scalar = isScalar;
    _opClass = sd::graph::OpClass_CONDITIONAL;
}
|
|
|
|
|
|
|
|
// std::string flavour of the (numInputs, opName, isScalar) constructor:
// one output, op class OpClass_CONDITIONAL, hash derived from the stored name.
//   numInputs - number of inputs the op takes
//   opName    - op name
//   isScalar  - value stored in the _scalar flag
OpDescriptor::OpDescriptor(int numInputs, std::string opName, bool isScalar) {
    this->_scalar = isScalar;
    this->_opClass = sd::graph::OpClass_CONDITIONAL;

    this->_numOutputs = 1;
    this->_numInputs = numInputs;

    // hash is computed from the member copy of the name, so assign it first
    this->_opName = opName;
    this->_hash = sd::ops::HashHelper::getInstance().getLongHash(this->_opName);
}
|
|
|
|
|
2020-02-13 18:59:35 +01:00
|
|
|
void OpDescriptor::allowInplace(bool reallyAllow){
|
|
|
|
_allowsInplace = reallyAllow;
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// Equality of two descriptors.
// When BOTH hashes are -1 the op numbers are compared; in every other case
// (including when only one hash is -1) the hashes themselves are compared.
// NOTE(review): -1 presumably marks "hash not set" - confirm against the header.
bool OpDescriptor::operator==(const OpDescriptor& other) const {
    const bool bothHashesMinusOne = (_hash == -1) && (other._hash == -1);

    return bothHashesMinusOne ? (_opNum == other._opNum)
                              : (_hash == other._hash);
}
|
|
|
|
|
|
|
|
// std::string convenience overload: forwards to the const char* constructor
// with the same (numInputs, numOutputs, opName, allowsInplace) arguments.
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace)
    : OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace) {
    // nothing else to do: the delegate sets every field
}
|
|
|
|
|
|
|
|
// Replaces the stored hash (the value returned by getHash() and compared in operator==).
void OpDescriptor::setHash(Nd4jLong hash) {
    this->_hash = hash;
}
|
|
|
|
|
|
|
|
// default constructor
|
|
|
|
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
|
|
|
|
_numInputs = numInputs;
|
|
|
|
_numOutputs = numOutputs;
|
|
|
|
|
|
|
|
std::string tmp(opName);
|
|
|
|
_opName = tmp;
|
|
|
|
_allowsInplace = allowsInplace;
|
2020-06-06 14:26:55 +02:00
|
|
|
_hash = sd::ops::HashHelper::getInstance().getLongHash(tmp);
|
2019-06-06 14:21:15 +02:00
|
|
|
_divergent = false;
|
|
|
|
|
|
|
|
// just default value
|
2020-03-02 10:49:41 +01:00
|
|
|
_opClass = sd::graph::OpClass_TRANSFORM;
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
// constructor for configurable op
|
|
|
|
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) {
|
|
|
|
_tArgs = tArgs;
|
|
|
|
_iArgs = iArgs;
|
|
|
|
}
|
|
|
|
|
|
|
|
// constructor for non-configurable divergent op
|
|
|
|
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace, divergent) {
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// constructor for non-configurable divergent op
|
|
|
|
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) {
|
|
|
|
_divergent = divergent;
|
|
|
|
}
|
|
|
|
|
|
|
|
// constructor for configurable divergent op
|
|
|
|
OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs) : OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) {
|
|
|
|
_divergent = divergent;
|
|
|
|
}
|
|
|
|
|
|
|
|
// default destructor
|
|
|
|
OpDescriptor::~OpDescriptor() {
|
|
|
|
//
|
|
|
|
}
|
|
|
|
|
|
|
|
int OpDescriptor::getNumberOfTArgs() {
|
|
|
|
return _tArgs;
|
|
|
|
}
|
|
|
|
|
|
|
|
int OpDescriptor::getNumberOfIArgs() {
|
|
|
|
return _iArgs;
|
|
|
|
}
|
|
|
|
|
|
|
|
int OpDescriptor::getNumberOfInputs() {
|
|
|
|
return _numInputs;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns the stored hash, as computed by a constructor or replaced via setHash().
Nd4jLong OpDescriptor::getHash() {
    return this->_hash;
}
|
|
|
|
|
|
|
|
int OpDescriptor::getNumberOfOutputs() {
|
|
|
|
return _numOutputs;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string * OpDescriptor::getOpName() {
|
|
|
|
return &_opName;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpDescriptor::isDivergent() {
|
|
|
|
return _divergent;
|
|
|
|
}
|
|
|
|
|
|
|
|
void OpDescriptor::setOpNum(int opNum) {
|
|
|
|
_opNum = opNum;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpDescriptor::allowsInplace() {
|
|
|
|
return _allowsInplace;
|
|
|
|
}
|
|
|
|
|
|
|
|
int OpDescriptor::getOpNum() {
|
|
|
|
return _opNum;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Stores the descriptor-wide InputType.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setInputType(const InputType type) {
    this->_inputType = type;
    return this;
}
|
|
|
|
|
|
|
|
// Returns the descriptor-wide InputType stored by setInputType(InputType).
InputType OpDescriptor::inputType() {
    return this->_inputType;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the global list of allowed input data types with the given list.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedInputTypes(const std::initializer_list<sd::DataType> &dtypes) {
    _allowedIns.assign(dtypes);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the global list of allowed output data types with the given list.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedOutputTypes(const std::initializer_list<sd::DataType> &dtypes) {
    _allowedOuts.assign(dtypes);
    return this;
}
|
|
|
|
|
2019-08-23 11:31:12 +02:00
|
|
|
// Stores the data-type override flag (_dtypeOverride).
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) {
    this->_dtypeOverride = allowOverride;
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the global list of allowed input data types with exactly one type;
// anything previously in the list is discarded.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedInputTypes(const sd::DataType dtype) {
    _allowedIns.assign(1, dtype);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the global list of allowed output data types with exactly one type;
// anything previously in the list is discarded.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedOutputTypes(const sd::DataType dtype) {
    _allowedOuts.assign(1, dtype);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Sets the allowed types for input #idx to exactly one type, replacing any
// list previously registered for that input.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setInputType(const int idx, const sd::DataType dtype) {
    auto &allowedForInput = _inputTypes[idx];   // creates the entry on first use
    allowedForInput.clear();
    allowedForInput.push_back(dtype);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Sets the allowed types for output #idx to exactly one type, replacing any
// list previously registered for that output.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setOutputType(const int idx, const sd::DataType dtype) {
    auto &allowedForOutput = _outputTypes[idx];   // creates the entry on first use
    allowedForOutput.clear();
    allowedForOutput.push_back(dtype);
    return this;
}
|
|
|
|
|
|
|
|
// Stores the "same mode" flag reported by isSameMode().
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setSameMode(const bool reallySame) {
    this->_sameMode = reallySame;
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the list of allowed types for input #index with a copy of the
// given list. Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, const std::vector<sd::DataType> &dtype) {
    auto &allowedForInput = _inputTypes[index];   // creates the entry on first use
    allowedForInput = dtype;
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Replaces the list of allowed types for output #index with a copy of the
// given list. Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, const std::vector<sd::DataType> &dtype) {
    auto &allowedForOutput = _outputTypes[index];   // creates the entry on first use
    allowedForOutput = dtype;
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Adds one more allowed type for input #index. Unlike the list-taking
// overloads this APPENDS: a first call for an index starts a one-element
// list, later calls extend it.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, sd::DataType dtype) {
    // operator[] yields an empty list for an unseen index, so appending covers
    // both the "new entry" and the "existing entry" case
    _inputTypes[index].push_back(dtype);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Adds one more allowed type for output #index. Unlike the list-taking
// overloads this APPENDS: a first call for an index starts a one-element
// list, later calls extend it.
// Returns this, so setter calls can be chained.
OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, sd::DataType dtype) {
    // operator[] yields an empty list for an unseen index, so appending covers
    // both the "new entry" and the "existing entry" case
    _outputTypes[index].push_back(dtype);
    return this;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
// Decides whether data type `needle` is acceptable given the allowed-type
// list `haystack`.
//   - an empty list accepts every type;
//   - otherwise the type is accepted when the list contains either the type
//     itself or the wildcard sd::DataType::ANY.
// Returns true when accepted, false otherwise.
bool OpDescriptor::checkDataTypesMatch(sd::DataType needle, std::vector<sd::DataType> &haystack) const {
    // nothing registered - any type is fine
    if (haystack.empty())
        return true;

    // a direct hit and the ANY wildcard are equally good, so one pass suffices
    for (const auto &allowed : haystack) {
        if (allowed == needle || allowed == sd::DataType::ANY)
            return true;
    }

    return false;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool OpDescriptor::checkInputMatch(int index, sd::DataType dataType) {
|
2019-06-06 14:21:15 +02:00
|
|
|
// we check for per-input types first
|
|
|
|
if (_inputTypes.empty() || _inputTypes.count(index) == 0) {
|
|
|
|
// checking global input types
|
|
|
|
return checkDataTypesMatch(dataType, _allowedIns);
|
|
|
|
} else {
|
|
|
|
// checking data type for specified input
|
|
|
|
auto allowed = _inputTypes[index];
|
|
|
|
return checkDataTypesMatch(dataType, allowed);
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool OpDescriptor::checkOutputMatch(int index, sd::DataType dataType) {
|
2019-06-06 14:21:15 +02:00
|
|
|
// we check for per-output types first
|
|
|
|
if (_outputTypes.empty() || _outputTypes.count(index) == 0) {
|
|
|
|
|
|
|
|
// checking global output types
|
|
|
|
return checkDataTypesMatch(dataType, _allowedOuts);
|
|
|
|
} else {
|
|
|
|
// checking data type for specified output
|
|
|
|
auto allowed = _outputTypes[index];
|
|
|
|
return checkDataTypesMatch(dataType, allowed);
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpDescriptor::isSameMode() {
|
|
|
|
return _sameMode;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpDescriptor::isInherit(int index) {
|
2020-03-02 10:49:41 +01:00
|
|
|
if (std::find(_allowedOuts.begin(), _allowedOuts.end(), sd::DataType::INHERIT) != _allowedOuts.end())
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
if (_outputTypes.count(index) > 0) {
|
|
|
|
auto vec = _outputTypes[index];
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
if (std::find(vec.begin(), vec.end(), sd::DataType::INHERIT) != vec.end())
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<sd::DataType> OpDescriptor::getOutputTypesForOutput(int index) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (_outputTypes.count(index) > 0)
|
|
|
|
return _outputTypes.at(index);
|
|
|
|
else
|
2020-03-02 10:49:41 +01:00
|
|
|
return std::vector<sd::DataType>();
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|