213 lines
8.8 KiB
C++
213 lines
8.8 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// @author raver119@gmail.com
|
|
//
|
|
|
|
#ifndef LIBND4J_DECLARABLE_OPS_H
|
|
#define LIBND4J_DECLARABLE_OPS_H
|
|
|
|
#include <cstdio>
#include <initializer_list>
#include <sstream>
#include <string>
#include <vector>
#include <types/float16.h>
#include <pointercast.h>
#include <NDArray.h>
#include <graph/Context.h>
#include "OpDescriptor.h"
#include <helpers/helper_hash.h>
#include <array/ShapeList.h>
#include <array/ResultSet.h>
#include <helpers/OpArgsHolder.h>
#include <dll.h>
#include <ops/declarable/EmptyHandling.h>
//#include <ops/declarable/declarable_ops.h>

#include <chrono>
#include <ctime>
#include <mutex>
|
|
|
|
using namespace nd4j::graph;
|
|
|
|
namespace nd4j {
|
|
namespace ops {
|
|
|
|
/**
 * Exported validation helper; only declared here, the definition lives in another translation unit.
 *
 * NOTE(review): judging by the signature it appears to be a printf-style checker that,
 * when `condition` is false/zero, reports `format` + varargs together with `file`, `line`
 * and `argNumber`, and returns a non-OK Nd4jStatus — confirm against the implementation.
 */
Nd4jStatus ND4J_EXPORT conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...);
|
|
|
|
|
|
template<typename T>
|
|
Nd4jStatus resultHelper(T status, const char *func, const char *file, int line) {
|
|
if (status) {
|
|
// TODO: fill out error codes here
|
|
fprintf(stderr, "Validation error at %s:%d code=%d(%s) \"%s\" \n", file, line,
|
|
static_cast<unsigned int>(status), "", func);
|
|
|
|
return ND4J_STATUS_BAD_INPUT;
|
|
}
|
|
|
|
return ND4J_STATUS_OK;
|
|
}
|
|
|
|
/**
|
|
* This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class.
|
|
*
|
|
*/
|
|
class ND4J_EXPORT DeclarableOp {
|
|
private:
|
|
std::mutex _registrator;
|
|
bool _registered = false;
|
|
|
|
protected:
|
|
OpDescriptor *_descriptor;
|
|
NDArray *_scalar = nullptr;
|
|
|
|
virtual void registerTypes();
|
|
|
|
/**
|
|
* This method executes this Op, and defined for most of individual ops separately
|
|
*/
|
|
virtual Nd4jStatus validateAndExecute(Context& block) = 0;
|
|
|
|
/**
|
|
* This method ensures that target variable has enough space for op execution
|
|
*
|
|
* TODO: we want workspaces support right here
|
|
*/
|
|
bool allocateResult(Context& block, std::initializer_list<Nd4jLong>& shape, char order = 'c');
|
|
bool allocateResult(Context& block, Nd4jLong* shape);
|
|
|
|
/**
|
|
* This method overwrites existen NDArray or NDArrayList in VariableSpace
|
|
*
|
|
* PLEASE NOTE: This method is dangerous.
|
|
*
|
|
* @param block
|
|
* @param numOutput
|
|
* @param array
|
|
*/
|
|
void overwriteResult(Context& block, int outputIdx, NDArray* array);
|
|
void overwriteResult(Context& block, int outputIdx, NDArrayList* list);
|
|
|
|
/*
|
|
* This method attaches array to specific Variable, identified by node ID and outputNumber (which is output index for multi-output operations)
|
|
*/
|
|
void storeResult(Context &block, int outputNumber, NDArray& array);
|
|
void storeResult(Context &block, int outputNumber, NDArray* array);
|
|
nd4j::NDArray* getZ(Context& block, int inputId = 0);
|
|
|
|
/**
|
|
* This method pre-allocates NDArrays for Op output, in case they are not available at op execution time
|
|
*/
|
|
int prepareOutputs(Context& block);
|
|
|
|
virtual samediff::EmptyHandling emptyHandling();
|
|
public:
|
|
// for special cases, like BooleanOps
|
|
DeclarableOp();
|
|
DeclarableOp(const char *name, int numInputs, bool scalar);
|
|
|
|
// regular constructors
|
|
DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace);
|
|
DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent);
|
|
DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs);
|
|
|
|
// for LogicalOps
|
|
DeclarableOp(const char *name, bool isLogical);
|
|
|
|
// default testructor
|
|
virtual ~DeclarableOp();
|
|
|
|
// this method returns OpDescriptor, describing this Op instance
|
|
OpDescriptor *getOpDescriptor();
|
|
|
|
Nd4jStatus validateDataTypes(Context& block);
|
|
|
|
/**
|
|
* This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s)
|
|
*/
|
|
virtual ShapeList* calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) = 0;
|
|
|
|
/**
|
|
* Returns opName
|
|
*
|
|
* @return
|
|
*/
|
|
std::string *getOpName();
|
|
|
|
/**
|
|
* Returns opHash
|
|
*/
|
|
Nd4jLong getOpHash();
|
|
|
|
/**
|
|
* This method sets arguments for op
|
|
*/
|
|
// void setArguments();
|
|
|
|
/**
|
|
* This method returns pointer to results
|
|
*/
|
|
// void getResults();
|
|
|
|
/**
|
|
* This method executes given Op
|
|
*
|
|
* @param block
|
|
* @return 0 if OK, error code otherwise
|
|
*/
|
|
virtual Nd4jStatus execute(Context* block);
|
|
|
|
nd4j::ResultSet* execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
Nd4jStatus execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
|
|
nd4j::ResultSet* execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs = std::vector<bool>(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
Nd4jStatus execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs , std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32);
|
|
|
|
nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false);
|
|
|
|
// There methods provide various validation options
|
|
Nd4jStatus validateNonEmptyInput(Context& block);
|
|
|
|
// this method checks if all input arrays have equal lengths
|
|
Nd4jStatus validateInputLengthMatch(Context& block);
|
|
|
|
// this method checks if all input arrays have the same shapes (orders/strides are NOT checked)
|
|
Nd4jStatus validateInputDimensionsMatch(Context& block);
|
|
|
|
// this method check if all input arrays have the same orders
|
|
Nd4jStatus validateOrdersMatch(Context& block);
|
|
|
|
// this method checks if all input arrays are 2D
|
|
Nd4jStatus validateInput2D(Context& block);
|
|
|
|
// this method checks if all input arrays are 3D
|
|
Nd4jStatus validateInput3D(Context& block);
|
|
|
|
// this method checks if all input arrays are 4D
|
|
Nd4jStatus validateInput4D(Context& block);
|
|
|
|
// this method checks if all input arrays are ND
|
|
Nd4jStatus validateInputDimensions(Context& block, int rank);
|
|
|
|
// this method checks if number of available arguments matches op expectations
|
|
Nd4jStatus validateArguments(Context& block);
|
|
};
|
|
}
|
|
}
|
|
|
|
#endif //LIBND4J_DECLARABLE_OPS_H
|