// File: cavis/libnd4j/include/ops/declarable/OpDescriptor.h
// (source-viewer metadata: 190 lines, 7.0 KiB, C++)

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_OPDESCRIPTOR_H
#define LIBND4J_OPDESCRIPTOR_H
#include <string>
#include <vector>
#include <initializer_list>
#include <helpers/helper_hash.h>
#include <ops/InputType.h>
#include <graph/generated/node_generated.h>
#include <array/DataType.h>
namespace sd {
namespace ops {
/**
* This class is very basic info holder for ops. bean/pojo pretty much.
*
*/
class ND4J_EXPORT OpDescriptor {
protected:
// opNum for legacy XYZ ops
int _opNum = 0;
// opName for CustomOp
std::string _opName;
// hash is used for ops lookup in OpRegistrator
Nd4jLong _hash = -1;
// minimal required/expected number of inputs/outpus for this given op
int _numInputs = 1;
int _numOutputs = 1;
// enum for ops. deprecated. will be removed
sd::graph::OpClass _opClass;
// special flag for divergent ops - ops that CAN and WILL modify graph behavior. Literally: IF, CASE.
bool _divergent = false;
// flag, if this given op allows in-place execution
bool _allowsInplace = true;
// minimal required number of T-type arguments.
// -1 as value means: not limited, variable number of arguments
int _tArgs = 0;
// minimal required number of Integer-type arguments.
// -1 as value means: not limited, variable number of arguments
int _iArgs = 0;
// field for BooleanOps
bool _scalar = false;
// field for LogicOps
bool _logic = false;
// default InputType is numeric
InputType _inputType = InputType_NUMERIC;
bool _sameMode = false;
std::vector<sd::DataType> _allowedIns;
std::vector<sd::DataType> _allowedOuts;
// optional per-input configuration
MAP_IMPL<int, std::vector<sd::DataType>> _outputTypes;
MAP_IMPL<int, std::vector<sd::DataType>> _inputTypes;
// field for ops that allow data type override at runtime
bool _dtypeOverride = false;
bool checkDataTypesMatch(sd::DataType needle, std::vector<sd::DataType> &haystack) const;
public:
// default constructor
OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace);
// constructor for boolean ops
OpDescriptor(int numInputs, std::string opName, bool isScalar);
OpDescriptor(int numInputs, const char* opName, bool isScalar);
// default constructor
OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace);
// constructor for configurable op
OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
// constructor for non-configurable divergent op
OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent);
// constructor for non-configurable divergent op
OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent);
// constructor for configurable divergent op
OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs);
// constructor for logical ops (while, scope, etc)
OpDescriptor(const char * opName, bool isLogic);
bool operator==(const OpDescriptor& other) const;
// default destructor
~OpDescriptor();
// this method returns minimal expected number of T arguments
int getNumberOfTArgs();
// this method returns minimal expected number of Integer arguments
int getNumberOfIArgs();
// this method returns minimal expected number of inputs
int getNumberOfInputs();
// this method returns hash code for this operation
Nd4jLong getHash();
// this method returns minimal expected number of outputs
int getNumberOfOutputs();
// this method returns opName (can be empty)
std::string *getOpName();
// returns TRUE if this op is divergent. FALSE otherwise
bool isDivergent();
// returns TRUE if this op allows in-place execution
bool allowsInplace();
// this method allows you to enable/disable inplace call for a given op
void allowInplace(bool reallyAllow);
// this method returns opNum (applicable for legacy XYZ ops only)
int getOpNum();
// this method allows to set specifc opNum
void setOpNum(int opNum);
void setHash(Nd4jLong hash);
InputType inputType();
OpDescriptor* setInputType(InputType type);
OpDescriptor* setAllowedInputTypes(const std::initializer_list<sd::DataType> &dtype);
OpDescriptor* setAllowedOutputTypes(const std::initializer_list<sd::DataType> &dtype);
OpDescriptor* setAllowedInputTypes(int index, const std::vector<sd::DataType> &dtype);
OpDescriptor* setAllowedOutputTypes(int index, const std::vector<sd::DataType> &dtype);
OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype);
OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype);
OpDescriptor* setAllowedInputTypes(sd::DataType dtype);
OpDescriptor* setAllowedOutputTypes(sd::DataType dtype);
OpDescriptor* allowOverride(bool reallyAllow);
OpDescriptor* setSameMode(bool reallySame);
OpDescriptor* setInputType(int idx, sd::DataType dtype);
OpDescriptor* setOutputType(int idx, sd::DataType dtype);
std::vector<sd::DataType> getOutputTypesForOutput(int index);
bool checkInputMatch(int index, sd::DataType dataType);
bool checkOutputMatch(int index, sd::DataType dataType);
bool isSameMode();
bool isInherit(int index);
};
}
}
#endif //LIBND4J_OPDESCRIPTOR_H