/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_OPDESCRIPTOR_H
#define LIBND4J_OPDESCRIPTOR_H

// NOTE(review): the include targets below are empty in this view of the file
// (the angle-bracket contents appear to have been stripped) - restore them
// from the upstream header before building.
#include
#include
#include
#include
#include
#include
#include

namespace sd {
    namespace ops {

        /**
         * Plain info holder (bean/pojo style) describing a single op: its name
         * or number, lookup hash, expected input/output/argument counts and
         * data type constraints.
         *
         * Only declarations live here; every member function is defined
         * outside of this header.
         *
         * NOTE(review): several declarations below are missing their template
         * arguments in this view (e.g. `std::vector`, `MAP_IMPL>`,
         * `std::initializer_list`). Presumably they are parameterised on
         * sd::DataType - confirm against the upstream header.
         */
        class ND4J_EXPORT OpDescriptor {
        protected:
            // opNum for legacy XYZ ops
            int _opNum = 0;

            // opName for CustomOp
            std::string _opName;

            // hash is used for ops lookup in OpRegistrator; initialised to -1 here
            Nd4jLong _hash = -1;

            // minimal required/expected number of inputs/outputs for this given op
            int _numInputs = 1;
            int _numOutputs = 1;

            // enum for ops. deprecated. will be removed.
            // Note: this is the only scalar field without an in-class initializer.
            sd::graph::OpClass _opClass;

            // special flag for divergent ops - ops that CAN and WILL modify graph behavior. Literally: IF, CASE.
            bool _divergent = false;

            // flag, if this given op allows in-place execution
            bool _allowsInplace = true;

            // minimal required number of T-type arguments.
            // -1 as value means: not limited, variable number of arguments
            int _tArgs = 0;

            // minimal required number of Integer-type arguments.
            // -1 as value means: not limited, variable number of arguments
            int _iArgs = 0;

            // field for BooleanOps
            bool _scalar = false;

            // field for LogicOps
            bool _logic = false;

            // default InputType is numeric
            InputType _inputType = InputType_NUMERIC;

            // presumably the state behind setSameMode()/isSameMode() - confirm in the implementation
            bool _sameMode = false;

            // presumably the data types accepted for any input / output,
            // as configured via setAllowedInputTypes()/setAllowedOutputTypes() - TODO confirm
            std::vector _allowedIns;
            std::vector _allowedOuts;

            // optional per-input configuration
            MAP_IMPL> _outputTypes;
            MAP_IMPL> _inputTypes;

            // field for ops that allow data type override at runtime
            bool _dtypeOverride = false;

            // internal helper; judging by the parameter names it presumably reports
            // whether `needle` is contained in `haystack` - verify in the implementation
            bool checkDataTypesMatch(sd::DataType needle, std::vector &haystack) const;
        public:
            // basic constructor: input/output counts, op name and in-place flag
            OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace);

            // constructor for boolean ops
            OpDescriptor(int numInputs, std::string opName, bool isScalar);
            OpDescriptor(int numInputs, const char* opName, bool isScalar);

            // basic constructor, C-string flavour of the first overload
            OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace);

            // constructor for configurable op
            OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);

            // constructor for non-configurable divergent op
            OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent);

            // constructor for non-configurable divergent op
            OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent);

            // constructor for configurable divergent op
            OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs);

            // constructor for logical ops (while, scope, etc)
            OpDescriptor(const char * opName, bool isLogic);

            // equality comparison; the compared fields are chosen in the implementation (not visible here)
            bool operator==(const OpDescriptor& other) const;

            // destructor (user-declared, defined out of line)
            ~OpDescriptor();

            // this method returns minimal expected number of T arguments
            int getNumberOfTArgs();

            // this method returns minimal expected number of Integer arguments
            int getNumberOfIArgs();

            // this method returns minimal expected number of inputs
            int getNumberOfInputs();

            // this method returns hash code for this operation
            Nd4jLong getHash();

            // this method returns minimal expected number of outputs
            int getNumberOfOutputs();

            // this method returns opName (can be empty).
            // NOTE(review): returns a raw pointer; presumably it points at _opName and
            // is only valid while this descriptor is alive - confirm
            std::string *getOpName();

            // returns TRUE if this op is divergent. FALSE otherwise
            bool isDivergent();

            // returns TRUE if this op allows in-place execution
            bool allowsInplace();

            // this method allows you to enable/disable inplace call for a given op
            void allowInplace(bool reallyAllow);

            // this method returns opNum (applicable for legacy XYZ ops only)
            int getOpNum();

            // this method allows to set specific opNum
            void setOpNum(int opNum);

            // sets the hash used for op lookup
            void setHash(Nd4jLong hash);

            // returns the InputType of this op
            InputType inputType();

            // The setters below return OpDescriptor*; presumably `this`, so that calls
            // can be chained - confirm in the implementation.

            // sets the InputType of this op
            OpDescriptor* setInputType(InputType type);

            // data type constraints given as a list of types
            OpDescriptor* setAllowedInputTypes(const std::initializer_list &dtype);
            OpDescriptor* setAllowedOutputTypes(const std::initializer_list &dtype);

            // data type constraints for the input/output at a specific index, given as a vector
            OpDescriptor* setAllowedInputTypes(int index, const std::vector &dtype);
            OpDescriptor* setAllowedOutputTypes(int index, const std::vector &dtype);

            // data type constraint for the input/output at a specific index, single type
            OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype);
            OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype);

            // data type constraint given as a single type, no index
            OpDescriptor* setAllowedInputTypes(sd::DataType dtype);
            OpDescriptor* setAllowedOutputTypes(sd::DataType dtype);

            // presumably toggles _dtypeOverride (runtime data type override) - confirm
            OpDescriptor* allowOverride(bool reallyAllow);

            // presumably toggles _sameMode - confirm
            OpDescriptor* setSameMode(bool reallySame);

            // per-index data type configuration for a single input/output
            OpDescriptor* setInputType(int idx, sd::DataType dtype);
            OpDescriptor* setOutputType(int idx, sd::DataType dtype);

            // returns the data types associated with the output at `index`
            std::vector getOutputTypesForOutput(int index);

            // check a data type against the configuration of the input/output at `index`;
            // the exact matching rules live in the implementation (not visible here)
            bool checkInputMatch(int index, sd::DataType dataType);
            bool checkOutputMatch(int index, sd::DataType dataType);

            // presumably returns _sameMode - confirm
            bool isSameMode();

            // NOTE(review): semantics are not derivable from this header; the name suggests
            // "output at `index` inherits its data type" - verify against the implementation
            bool isInherit(int index);
        };
    }
}

#endif //LIBND4J_OPDESCRIPTOR_H