2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
2020-01-30 08:07:24 +01:00
|
|
|
// @author raver119@gmail.com
|
2019-06-06 14:21:15 +02:00
|
|
|
//
|
|
|
|
|
|
|
|
#include <ops/declarable/DeclarableOp.h>
|
|
|
|
#include <Status.h>
|
|
|
|
#include <helpers/ShapeUtils.h>
|
|
|
|
#include <NDArrayFactory.h>
|
|
|
|
#include <exceptions/graph_exception.h>
|
|
|
|
#include <exceptions/unresolved_input_exception.h>
|
2019-09-11 20:50:28 +02:00
|
|
|
#include <ops/declarable/OpRegistrator.h>
|
2019-11-19 10:53:52 +01:00
|
|
|
#include <exceptions/datatype_exception.h>
|
|
|
|
#include <helpers/StringUtils.h>
|
2020-01-30 08:07:24 +01:00
|
|
|
#include <cstdarg>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
namespace nd4j {
|
|
|
|
namespace ops {
|
|
|
|
Nd4jStatus conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...) {
|
|
|
|
if (!condition) {
|
|
|
|
va_list args;
|
|
|
|
|
|
|
|
printf("Error at [%s:%i:%i]:\n", file, line, argNumber);
|
|
|
|
va_start(args, format);
|
|
|
|
vprintf(format, args);
|
|
|
|
va_end(args);
|
|
|
|
printf("\n");
|
|
|
|
fflush(stdout);
|
|
|
|
|
|
|
|
return ND4J_STATUS_BAD_PARAMS;
|
|
|
|
}
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp() {
|
|
|
|
// no-op
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
|
|
|
|
_descriptor = new OpDescriptor(name, isLogical);
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
|
|
|
|
_descriptor = new OpDescriptor(numInputs, name, scalar);
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
|
|
|
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
|
|
|
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
|
|
|
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
|
|
|
|
}
|
|
|
|
|
|
|
|
DeclarableOp::~DeclarableOp() {
|
|
|
|
if (_descriptor != nullptr)
|
|
|
|
delete _descriptor;
|
|
|
|
|
|
|
|
if (_scalar != nullptr)
|
|
|
|
delete _scalar;
|
|
|
|
}
|
|
|
|
|
|
|
|
OpDescriptor* DeclarableOp::getOpDescriptor() {
|
|
|
|
return _descriptor;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string *DeclarableOp::getOpName() {
|
|
|
|
return _descriptor->getOpName();
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jLong DeclarableOp::getOpHash() {
|
|
|
|
return _descriptor->getHash();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
nd4j::NDArray* nd4j::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
|
|
|
|
NDArray* z = nullptr;
|
|
|
|
|
|
|
|
if (ctx.isFastPath()) {
|
|
|
|
if (ctx.fastpath_out().size() <= inputId) {
|
|
|
|
if (ctx.isInplace()) {
|
|
|
|
z = ctx.fastpath_in()[inputId];
|
|
|
|
} else
|
|
|
|
throw std::runtime_error("fastpath_out: unresolved output array");
|
|
|
|
} else {
|
|
|
|
z = ctx.fastpath_out()[inputId];
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
std::pair<int, int> pair(ctx.nodeId(), inputId);
|
|
|
|
|
|
|
|
if (ctx.isInplace()) {
|
|
|
|
z = ctx.variable(inputId)->getNDArray();
|
|
|
|
|
|
|
|
// hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now
|
|
|
|
if (!ctx.getVariableSpace()->hasVariable(pair)) {
|
|
|
|
auto var = new Variable();
|
|
|
|
ctx.getVariableSpace()->putVariable(pair, var);
|
|
|
|
}
|
|
|
|
|
|
|
|
// now we're saving input array as output array
|
|
|
|
auto var = ctx.getVariableSpace()->getVariable(pair);
|
|
|
|
var->markRemovable(false);
|
|
|
|
var->setNDArray(z);
|
|
|
|
} else if (!ctx.isInplace()) {
|
|
|
|
auto var = ctx.variable(pair);
|
|
|
|
if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) {
|
|
|
|
z = var->getNDArray();
|
|
|
|
} else {
|
|
|
|
nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId());
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
nd4j_printf("BOOM!\n", "");
|
|
|
|
throw std::runtime_error("Boom!");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return z;
|
|
|
|
}
|
|
|
|
|
|
|
|
int nd4j::ops::DeclarableOp::prepareOutputs(Context &ctx) {
|
|
|
|
auto workspace = ctx.getWorkspace();
|
|
|
|
GraphProfile *prof = nullptr;
|
|
|
|
NodeProfile *node = nullptr;
|
|
|
|
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
|
|
|
|
|
|
|
if (Environment::getInstance()->isProfiling()) {
|
|
|
|
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
|
|
|
|
prof = ctx.getVariableSpace()->flowPath()->profile();
|
|
|
|
node = prof->nodeById(ctx.nodeId());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (ctx.isInplace()) {
|
|
|
|
// do nothing, getZ result will do the trick
|
|
|
|
return static_cast<int>(ctx.width());
|
|
|
|
} else {
|
|
|
|
// if op is not inplace - we should pre-allocate arrays
|
|
|
|
|
|
|
|
ShapeList inSha;
|
|
|
|
int results = 0;
|
|
|
|
|
|
|
|
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
|
|
|
inputStart = std::chrono::system_clock::now();
|
|
|
|
|
|
|
|
int cntIn = 0;
|
|
|
|
// we build list of input shapes
|
|
|
|
if (ctx.isFastPath()) {
|
|
|
|
for (const auto p:ctx.fastpath_in()) {
|
2020-01-30 08:07:24 +01:00
|
|
|
if (p == nullptr)
|
|
|
|
continue;
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
inSha.push_back(p->getShapeInfo());
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for (auto p: *ctx.inputs()) {
|
|
|
|
auto var = ctx.variable(p);
|
|
|
|
if (var->variableType() == VariableType::NDARRAY) {
|
|
|
|
NDArray *array = var->getNDArray();
|
|
|
|
if (array == nullptr)
|
|
|
|
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
|
|
|
|
|
|
|
|
inSha.push_back(array->getShapeInfo());
|
|
|
|
}
|
|
|
|
cntIn++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// optionally saving input time
|
|
|
|
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
|
|
|
inputEnd = std::chrono::system_clock::now();
|
|
|
|
auto inputTime = std::chrono::duration_cast<std::chrono::nanoseconds>(inputEnd - inputStart).count();
|
|
|
|
node->setInputTime(inputTime);
|
|
|
|
|
|
|
|
shapeStart = std::chrono::system_clock::now();
|
|
|
|
}
|
|
|
|
|
2020-01-04 07:06:44 +01:00
|
|
|
// if we override shape function, we'll return size of fastPath
|
|
|
|
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
|
|
|
|
return (int) ctx.fastpath_out().size();
|
|
|
|
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
auto outSha = this->calculateOutputShape(&inSha, ctx);
|
|
|
|
results = outSha->size();
|
|
|
|
|
|
|
|
// optionally saving shapeTime
|
|
|
|
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
|
|
|
shapeEnd = std::chrono::system_clock::now();
|
|
|
|
auto prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(shapeEnd - shapeStart).count();
|
|
|
|
node->setShapeFunctionTime(prepTime);
|
|
|
|
|
|
|
|
arrayStart = std::chrono::system_clock::now();
|
|
|
|
}
|
|
|
|
|
|
|
|
int cnt = 0;
|
|
|
|
for (auto out: *outSha->asVector()) {
|
|
|
|
if (!ctx.isFastPath()) {
|
|
|
|
// we need to check, if Z is really needed
|
|
|
|
std::pair<int, int> pair(ctx.nodeId(), cnt++);
|
|
|
|
|
|
|
|
if (!ctx.isValueAvailable(pair.second)) {
|
|
|
|
if (Environment::getInstance()->isDebugAndVerbose())
|
|
|
|
shape::printShapeInfoLinear("Going to create variable with shape", out);
|
|
|
|
|
|
|
|
auto outArr = new NDArray(out, true, ctx.launchContext());
|
|
|
|
|
|
|
|
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
|
|
|
} else {
|
|
|
|
// validate/compare shapes here. existent vs provided in outSha
|
|
|
|
auto var = ctx.variable(pair);
|
|
|
|
auto shape = var->getNDArray()->shapeInfo();
|
|
|
|
|
2019-10-23 11:11:25 +02:00
|
|
|
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto eShape = ShapeUtils::shapeAsString(out);
|
|
|
|
auto aShape = ShapeUtils::shapeAsString(shape);
|
|
|
|
|
|
|
|
//outSha->destroy();
|
|
|
|
delete outSha;
|
|
|
|
|
|
|
|
nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second);
|
|
|
|
throw std::runtime_error("Expected vs provided shapes mismatch");
|
|
|
|
}
|
2019-11-19 10:53:52 +01:00
|
|
|
|
|
|
|
/*
|
|
|
|
* FIXME: we want to uncomment this eventually, and check data types equality
|
|
|
|
//checking out data type equality
|
|
|
|
if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) {
|
|
|
|
std::string msg = "Provided array [" + StringUtils::valueToString<int>(pair.second) + "] has unexpected data type";
|
|
|
|
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape));
|
|
|
|
}
|
|
|
|
*/
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
auto fout = ctx.fastpath_out();
|
|
|
|
auto idx = cnt++;
|
|
|
|
if (fout.size() <= idx) {
|
|
|
|
// array doesnt exist
|
|
|
|
auto outArr = new NDArray(out, true, ctx.launchContext());
|
|
|
|
ctx.setOutputArray(idx, outArr, true);
|
|
|
|
} else {
|
|
|
|
auto array = fout[idx];
|
2019-11-19 10:53:52 +01:00
|
|
|
// checking out shape equality
|
2019-10-23 11:11:25 +02:00
|
|
|
if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto eShape = ShapeUtils::shapeAsString(out);
|
|
|
|
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
|
|
|
|
|
|
|
|
//outSha->destroy();
|
|
|
|
delete outSha;
|
|
|
|
|
|
|
|
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
|
|
|
|
throw std::runtime_error("Expected vs provided shape mismatch");
|
|
|
|
}
|
2019-11-19 10:53:52 +01:00
|
|
|
|
|
|
|
/*
|
|
|
|
* FIXME: we want to uncomment this eventually, and check data types equality
|
|
|
|
//checking out data type equality
|
|
|
|
if (ArrayOptions::dataType(out) != array->dataType()) {
|
|
|
|
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
|
|
|
|
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
|
|
|
|
}
|
|
|
|
*/
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//outSha->destroy();
|
|
|
|
delete outSha;
|
|
|
|
|
|
|
|
// saving arrayTime
|
|
|
|
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
|
|
|
arrayEnd = std::chrono::system_clock::now();
|
|
|
|
auto arrayTime = std::chrono::duration_cast<std::chrono::nanoseconds>(arrayEnd - arrayStart).count();
|
|
|
|
node->setArrayTime(arrayTime);
|
|
|
|
}
|
|
|
|
|
|
|
|
return results;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void nd4j::ops::DeclarableOp::storeResult(Context &block, int outputNumber, NDArray* array) {
|
|
|
|
this->storeResult(block, outputNumber, *array);
|
|
|
|
}
|
|
|
|
|
|
|
|
void nd4j::ops::DeclarableOp::storeResult(nd4j::graph::Context &ctx, int outputNumber, NDArray& array) {
|
|
|
|
ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, &array, !ctx.isInplace());
|
|
|
|
}
|
|
|
|
|
|
|
|
bool nd4j::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) {
|
|
|
|
auto var = block.variable(block.getNodeId(), 0);
|
|
|
|
|
|
|
|
auto workspace = block.getWorkspace();
|
|
|
|
|
|
|
|
Nd4jLong len = shape::length(shape);
|
|
|
|
Nd4jLong* __shape;
|
|
|
|
ALLOCATE(__shape, workspace, shape::shapeInfoLength(shape), Nd4jLong); //new int[shape[0] * 2 + 4];
|
|
|
|
|
|
|
|
memcpy(__shape, shape, shape::shapeInfoByteLength(shape));
|
|
|
|
|
|
|
|
// if that's first run - we probably have nothing here
|
|
|
|
if (var->getNDArray() == nullptr) {
|
|
|
|
|
|
|
|
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace);
|
|
|
|
var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext()));
|
|
|
|
}
|
|
|
|
else if(var->getNDArray()->lengthOf() != len) {
|
|
|
|
// if length not match - lets reallocate array
|
|
|
|
delete var->getNDArray();
|
|
|
|
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace);
|
|
|
|
var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext()));
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool nd4j::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list<Nd4jLong>& shape, char order) {
|
|
|
|
auto var = block.variable(block.getNodeId(), 0);
|
|
|
|
auto workspace = block.getWorkspace();
|
|
|
|
|
|
|
|
Nd4jLong len = shape::length(shape);
|
|
|
|
// if that's first run - we probably have nothing here
|
|
|
|
if (var->getNDArray() == nullptr) {
|
|
|
|
var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext()));
|
|
|
|
} else if(var->getNDArray()->lengthOf() != len) {
|
|
|
|
// if length not match - lets reallocate array
|
|
|
|
delete var->getNDArray();
|
|
|
|
var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext()));
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateDataTypes(Context& block) {
|
|
|
|
_registrator.lock();
|
|
|
|
if (!_registered) {
|
|
|
|
_registered = true;
|
|
|
|
this->registerTypes();
|
|
|
|
}
|
|
|
|
_registrator.unlock();
|
|
|
|
|
|
|
|
// rolling over inputs first
|
|
|
|
int cnt = 0, inT = 0;
|
|
|
|
std::vector<nd4j::DataType> inputTypes(block.width());
|
2019-08-26 18:57:51 +02:00
|
|
|
if (block.isFastPath()) {
|
|
|
|
for (auto array: block.fastpath_in()) {
|
2020-01-30 08:07:24 +01:00
|
|
|
if (array == nullptr)
|
|
|
|
continue;
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
inputTypes[inT++] = array->dataType();
|
|
|
|
if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
|
|
|
|
auto ctype = DataTypeUtils::asString(array->dataType());
|
2019-08-26 18:57:51 +02:00
|
|
|
nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), cnt, ctype.c_str());
|
2019-06-06 14:21:15 +02:00
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
2019-08-26 18:57:51 +02:00
|
|
|
cnt++;
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
2019-08-26 18:57:51 +02:00
|
|
|
} else {
|
|
|
|
for (auto &p: *(block.inputs())) {
|
|
|
|
auto var = block.variable(p);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
// we're not checking validity, if ANY types were explicitly allowed
|
|
|
|
//if (block.dataType(cnt) == nd4j::DataType::ANY)
|
|
|
|
// continue;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// only validating non-null variables
|
|
|
|
if (var != nullptr && var->hasNDArray()) {
|
|
|
|
auto array = var->getNDArray();
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
inputTypes[inT++] = array->dataType();
|
|
|
|
if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
|
|
|
|
auto ctype = DataTypeUtils::asString(array->dataType());
|
|
|
|
nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), cnt, ctype.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
cnt++;
|
|
|
|
}
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
if (block.isFastPath()) {
|
|
|
|
int index = 0;
|
|
|
|
for (auto array: block.fastpath_out()) {
|
2020-01-30 08:07:24 +01:00
|
|
|
if (array == nullptr)
|
|
|
|
continue;
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto cType = array->dataType();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
if (_descriptor->isSameMode()) {
|
|
|
|
|
|
|
|
if (index >= block.width()) {
|
2019-08-27 20:00:38 +02:00
|
|
|
if (block.fastpath_in().size() == 0)
|
|
|
|
continue;
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto ia = block.fastpath_in()[0];
|
|
|
|
|
|
|
|
if (ia->dataType() != cType) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto t = DataTypeUtils::asString(cType);
|
2019-08-26 18:57:51 +02:00
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
2019-06-06 14:21:15 +02:00
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
2019-08-26 18:57:51 +02:00
|
|
|
} else {
|
|
|
|
// for same mode, output type must be the same as input type
|
|
|
|
auto ia = block.fastpath_in()[index];
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
if (ia->dataType() != cType) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else if (_descriptor->isInherit(index)) {
|
|
|
|
// in inherit mode, output type must be the same as one of input types
|
|
|
|
if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto t = DataTypeUtils::asString(cType);
|
2019-08-26 18:57:51 +02:00
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
2019-06-06 14:21:15 +02:00
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
} else if (!_descriptor->checkOutputMatch(index, cType)) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
2019-08-26 18:57:51 +02:00
|
|
|
index++;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// checking optionally available outputs
|
|
|
|
auto varSpace = block.getVariableSpace();
|
|
|
|
for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
|
|
|
|
if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) {
|
|
|
|
auto var = block.variable(block.nodeId(), index);
|
|
|
|
|
|
|
|
// only validating non-null variables
|
|
|
|
if (var != nullptr && var->hasNDArray()) {
|
|
|
|
auto array = var->getNDArray();
|
|
|
|
auto cType = array->dataType();
|
|
|
|
|
|
|
|
if (_descriptor->isSameMode()) {
|
|
|
|
|
|
|
|
if (index >= block.width()) {
|
2019-08-27 20:00:38 +02:00
|
|
|
if (block.width() == 0)
|
|
|
|
continue;
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto iv = block.variable(0);
|
|
|
|
|
|
|
|
if (iv->getNDArray()->dataType() != cType) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// for same mode, output type must be the same as input type
|
|
|
|
auto iv = block.variable(index);
|
|
|
|
|
|
|
|
if (iv->getNDArray()->dataType() != cType) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else if (_descriptor->isInherit(index)) {
|
|
|
|
// in inherit mode, output type must be the same as one of input types
|
|
|
|
if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
|
|
|
|
} else if (!_descriptor->checkOutputMatch(index, cType)) {
|
|
|
|
auto t = DataTypeUtils::asString(cType);
|
|
|
|
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
|
|
|
|
_descriptor->getOpName()->data(), index, t.c_str());
|
|
|
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
break;
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::execute(Context* block) {
|
|
|
|
nd4j_debug("Executing op: [%s]\n", this->getOpName()->c_str());
|
|
|
|
|
|
|
|
std::chrono::time_point<std::chrono::system_clock> timeEnter, timeStart, timeEnd;
|
|
|
|
Nd4jLong prepTime, outerTime;
|
|
|
|
|
|
|
|
Nd4jLong memoryBefore = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
|
|
|
|
if (Environment::getInstance()->isProfiling())
|
|
|
|
timeEnter = std::chrono::system_clock::now();
|
|
|
|
|
|
|
|
// basic validation: ensure inputs are set
|
|
|
|
REQUIRE_OK(this->validateNonEmptyInput(*block));
|
|
|
|
|
|
|
|
// ensure number of IArgs, TArgs match our expectations
|
|
|
|
REQUIRE_OK(this->validateArguments(*block));
|
|
|
|
|
|
|
|
// validating data types for inputs and (optionally) outputs
|
|
|
|
REQUIRE_OK(this->validateDataTypes(*block));
|
|
|
|
|
|
|
|
|
|
|
|
// this method will allocate output NDArrays for this op
|
|
|
|
auto numOutputs = this->prepareOutputs(*block);
|
|
|
|
|
|
|
|
if (Environment::getInstance()->isProfiling()) {
|
|
|
|
timeStart = std::chrono::system_clock::now();
|
|
|
|
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
|
|
|
|
}
|
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
|
|
|
|
Nd4jStatus status;
|
|
|
|
bool hasHelper = false;
|
|
|
|
|
2019-11-14 12:35:02 +01:00
|
|
|
// platform helpers use might be forbidden for various reasons, so we'll check it out first
|
|
|
|
if (block->helpersAllowed() && nd4j::Environment::getInstance()->helpersAllowed()) {
|
|
|
|
// if we have platform-specific helper for this op - invoke it
|
2020-01-20 19:32:46 +01:00
|
|
|
if (OpRegistrator::getInstance()->hasHelper(this->getOpHash(), block->engine())) {
|
|
|
|
auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash(), block->engine());
|
2019-11-14 12:35:02 +01:00
|
|
|
if (helper->isUsable(*block)) {
|
|
|
|
status = helper->invokeHelper(*block);
|
|
|
|
hasHelper = true;
|
|
|
|
}
|
2019-09-11 20:50:28 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// if we don't have platform-specific helper - invoke generic implementation
|
|
|
|
if (!hasHelper)
|
|
|
|
status = this->validateAndExecute(*block);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// optionally saving execution time
|
|
|
|
if (Environment::getInstance()->isProfiling()) {
|
|
|
|
timeEnd = std::chrono::system_clock::now();
|
|
|
|
outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
|
|
|
|
block->setInnerTime(outerTime);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (Environment::getInstance()->isProfiling()) {
|
|
|
|
auto fp = block->getVariableSpace()->flowPath();
|
|
|
|
if (fp != nullptr) {
|
|
|
|
auto p = fp->profile();
|
|
|
|
if (p != nullptr) {
|
|
|
|
Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
|
|
|
|
Nd4jLong memoryUsed = memoryAfter - memoryBefore;
|
|
|
|
p->nodeById(block->nodeId())->setPreparationTime(prepTime);
|
|
|
|
p->nodeById(block->nodeId())->setExecutionTime(outerTime);
|
|
|
|
p->nodeById(block->nodeId())->setTotalSize(memoryUsed);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// now we print out all outputs for this node
|
|
|
|
if (nd4j::Environment::getInstance()->isDebugAndVerbose()) {
|
|
|
|
auto vs = block->getVariableSpace();
|
|
|
|
|
|
|
|
for (int e = 0; e < numOutputs; e++) {
|
|
|
|
// if given output index doesn't exist - we're done
|
|
|
|
|
|
|
|
if (!block->isFastPath()) {
|
|
|
|
if (!vs->hasVariable(block->nodeId(), e))
|
|
|
|
break;
|
|
|
|
} else {
|
|
|
|
// we have to check either in or out stack, depending on isInplace()
|
|
|
|
if (block->isInplace()) {
|
|
|
|
if (block->fastpath_in().size() <= e)
|
|
|
|
break;
|
|
|
|
} else {
|
|
|
|
if (block->fastpath_out().size() <= e)
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray();
|
|
|
|
|
|
|
|
auto shape = ShapeUtils::shapeAsString(array);
|
|
|
|
auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32);
|
|
|
|
auto type = DataTypeUtils::asString(array->dataType());
|
|
|
|
|
|
|
|
nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", block->nodeId(), e, shape.c_str(), type.c_str(), first.c_str());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return status;
|
|
|
|
}
|
|
|
|
|
|
|
|
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArray *array) {
|
|
|
|
throw std::runtime_error("Overwrite result used!");
|
|
|
|
//block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array);
|
|
|
|
/*
|
|
|
|
auto varSpace = block.getVariableSpace();
|
|
|
|
if (varSpace->hasVariable(block.getNodeId(), outputIdx)) {
|
|
|
|
auto var = varSpace->getVariable(block.getNodeId(), outputIdx);
|
|
|
|
if (var->getNDArray() != nullptr && var->isRemovable())
|
|
|
|
delete var->getNDArray();
|
|
|
|
|
|
|
|
var->setNDArray(array);
|
|
|
|
var->markRemovable(true);
|
|
|
|
} else {
|
|
|
|
auto var = new Variable(array, nullptr, block.getNodeId(), outputIdx);
|
|
|
|
varSpace->putVariable(block.getNodeId(), outputIdx, var);
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
}
|
|
|
|
|
|
|
|
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArrayList *list) {
|
|
|
|
throw std::runtime_error("Overwrite result used!");
|
|
|
|
//block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list);
|
|
|
|
/*
|
|
|
|
auto varSpace = block.getVariableSpace();
|
|
|
|
if (varSpace->hasVariable(block.getNodeId(), outputIdx)) {
|
|
|
|
auto var = varSpace->getVariable(block.getNodeId(), outputIdx);
|
|
|
|
var->setNDArrayList(list);
|
|
|
|
} else {
|
|
|
|
auto var = new Variable(nullptr, nullptr, block.getNodeId(), outputIdx);
|
|
|
|
var->setNDArrayList(list);
|
|
|
|
varSpace->putVariable(block.getNodeId(), outputIdx, var);
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateArguments(Context& block) {
|
|
|
|
/*
|
|
|
|
* We're checking number of T and I arguments. If number of args is finite number - we check strict equality
|
|
|
|
* If number of args is variable (-1), but variables MUST be present - we check for non-zero number of arguments
|
|
|
|
*/
|
|
|
|
if (_descriptor->getNumberOfTArgs() > 0) {
|
|
|
|
if ((int) block.getTArguments()->size() < _descriptor->getNumberOfTArgs()) {
|
|
|
|
nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfTArgs(), block.getTArguments()->size());
|
|
|
|
return ND4J_STATUS_BAD_PARAMS;
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
if (_descriptor->getNumberOfTArgs() == -1)
|
|
|
|
if (block.getTArguments()->size() == 0) {
|
|
|
|
nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
|
|
|
|
return ND4J_STATUS_BAD_PARAMS;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (_descriptor->getNumberOfIArgs() > 0) {
|
|
|
|
if ((int) block.getIArguments()->size() < _descriptor->getNumberOfIArgs()) {
|
|
|
|
nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfIArgs(), block.getIArguments()->size());
|
|
|
|
return ND4J_STATUS_BAD_PARAMS;
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
if (_descriptor->getNumberOfIArgs() == -1)
|
|
|
|
if (block.getIArguments()->size() == 0) {
|
|
|
|
nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
|
|
|
|
return ND4J_STATUS_BAD_PARAMS;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensions(Context& block, int rank) {
|
|
|
|
if (block.width() == 0)
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
|
|
|
|
for (auto p: *block.inputs()) {
|
|
|
|
auto v = block.variable(p);
|
|
|
|
NDArray *aV = v->getNDArray();
|
|
|
|
|
|
|
|
if (aV == nullptr)
|
|
|
|
return ND4J_STATUS_BAD_INPUT;
|
|
|
|
|
|
|
|
if (aV->rankOf() != rank)
|
|
|
|
return ND4J_STATUS_BAD_DIMENSIONS;
|
|
|
|
}
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInput2D(Context& block) {
|
|
|
|
return validateInputDimensions(block, 2);
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInput3D(Context& block) {
|
|
|
|
return validateInputDimensions(block, 3);
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInput4D(Context& block) {
|
|
|
|
return validateInputDimensions(block, 4);
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateNonEmptyInput(Context& block) {
|
|
|
|
if (this->getOpDescriptor()->getNumberOfInputs() == -2 || this->getOpDescriptor()->getNumberOfInputs() == 0)
|
|
|
|
return Status::OK();
|
|
|
|
|
|
|
|
if (block.width() < 1) {
|
|
|
|
nd4j_printf("%s: no operands provided for the op", this->getOpName()->c_str());
|
|
|
|
return ND4J_STATUS_BAD_INPUT;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int cnt = 0;
|
|
|
|
for (auto p: *block.inputs()) {
|
|
|
|
auto v = block.variable(p);
|
|
|
|
if (v == nullptr) {
|
|
|
|
if (this->getOpName() != nullptr) {
|
|
|
|
nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second);
|
|
|
|
} else {
|
|
|
|
nd4j_printf("Node [%i:<noname>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second);
|
|
|
|
}
|
|
|
|
return ND4J_STATUS_BAD_INPUT;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (v->variableType() == VariableType::NDARRAY) {
|
|
|
|
NDArray *aV = v->getNDArray();
|
|
|
|
|
|
|
|
// if array is empty intentionally - we're ok with that
|
|
|
|
if (v->hasNDArray() && v->isEmpty())
|
|
|
|
continue;
|
|
|
|
|
|
|
|
if (aV == nullptr || !aV->nonNull()) {
|
|
|
|
if (this->getOpName() != nullptr) {
|
|
|
|
nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second);
|
|
|
|
} else {
|
|
|
|
nd4j_printf("Node [%i:<noname>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second);
|
|
|
|
}
|
|
|
|
return ND4J_STATUS_BAD_INPUT;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
cnt++;
|
|
|
|
}
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateOrdersMatch(Context& block) {
|
|
|
|
if (block.width() == 0)
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
|
|
|
|
NDArray *a0 = block.variable(0)->getNDArray();
|
|
|
|
for (auto p: *block.inputs()) {
|
|
|
|
auto v = block.variable(p);
|
|
|
|
NDArray *aV = v->getNDArray();
|
|
|
|
if (a0->ordering() != aV->ordering())
|
|
|
|
return ND4J_STATUS_BAD_ORDER;
|
|
|
|
}
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<nd4j::DataType>& dArgs, bool isInplace, nd4j::DataType type) {
|
2019-06-06 14:21:15 +02:00
|
|
|
VariableSpace variableSpace;
|
|
|
|
FlowPath fp;
|
|
|
|
variableSpace.setFlowPath(&fp);
|
|
|
|
|
|
|
|
int cnt = -1;
|
|
|
|
std::vector<int> in;
|
|
|
|
for (auto v: inputs) {
|
|
|
|
if (v == nullptr)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
auto var = new Variable(v);
|
|
|
|
var->markRemovable(false);
|
|
|
|
in.push_back(cnt);
|
|
|
|
variableSpace.putVariable(cnt--, var);
|
|
|
|
}
|
|
|
|
|
|
|
|
int et = 0;
|
|
|
|
for (auto v: outputs) {
|
|
|
|
auto var = new Variable(v);
|
|
|
|
var->markRemovable(false);
|
|
|
|
std::pair<int,int> pair(1, et++);
|
|
|
|
variableSpace.putVariable(pair, var);
|
|
|
|
}
|
|
|
|
|
|
|
|
Context block(1, &variableSpace, false);
|
|
|
|
block.fillInputs(in);
|
|
|
|
block.markInplace(isInplace);
|
|
|
|
block.setDataType(0, type);
|
|
|
|
|
|
|
|
// we need this line for tests basically
|
|
|
|
//if (rng != nullptr)
|
|
|
|
block.setRng(rng);
|
|
|
|
|
|
|
|
for (int e = 0; e < tArgs.size(); e++)
|
|
|
|
block.getTArguments()->emplace_back(tArgs.at(e));
|
|
|
|
|
|
|
|
// FIXME: iargs should be Nd4jLong
|
|
|
|
for (int e = 0; e < iArgs.size(); e++)
|
|
|
|
block.getIArguments()->emplace_back(static_cast<int>(iArgs.at(e)));
|
|
|
|
|
|
|
|
for (int e = 0; e < bArgs.size(); e++)
|
|
|
|
block.getBArguments()->push_back(static_cast<int>(bArgs.at(e)));
|
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
for (int e = 0; e < dArgs.size(); e++)
|
|
|
|
block.getDArguments()->push_back(dArgs.at(e));
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jStatus result = this->execute(&block);
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs) {
|
|
|
|
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return execute(inputs, outputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<nd4j::DataType> dArgs) {
|
|
|
|
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), dArgs);
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<float> tArgs) {
|
|
|
|
std::vector<double> realArgs;
|
|
|
|
for (auto v:tArgs)
|
|
|
|
realArgs.emplace_back(v);
|
|
|
|
|
2020-01-31 08:45:41 +01:00
|
|
|
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return execute(inputs, outputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<int> iArgs) {
|
|
|
|
std::vector<Nd4jLong> realArgs;
|
|
|
|
for (auto v:iArgs)
|
|
|
|
realArgs.emplace_back(v);
|
|
|
|
|
2020-01-31 08:45:41 +01:00
|
|
|
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
|
|
|
Context ctx(1);
|
|
|
|
|
|
|
|
for (int e = 0; e < inputs.size(); e++) {
|
|
|
|
if (inputs[e] == nullptr)
|
|
|
|
break;
|
|
|
|
|
|
|
|
ctx.setInputArray(e, inputs[e]);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int e = 0; e < outputs.size(); e++) {
|
|
|
|
if (outputs[e] == nullptr)
|
|
|
|
break;
|
|
|
|
|
|
|
|
ctx.setOutputArray(e, outputs[e]);
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (isInplace)
|
|
|
|
ctx.markInplace(isInplace);
|
|
|
|
|
|
|
|
ctx.setIArguments(iArgs);
|
|
|
|
ctx.setTArguments(tArgs);
|
|
|
|
ctx.setBArguments(bArgs);
|
|
|
|
ctx.setDArguments(dArgs);
|
|
|
|
|
|
|
|
return execute(&ctx);
|
|
|
|
}
|
|
|
|
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
|
|
|
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
|
|
|
|
std::vector<Nd4jLong> realArgs;
|
|
|
|
for (auto v:iArgs)
|
|
|
|
realArgs.emplace_back(v);
|
|
|
|
|
2020-01-31 08:45:41 +01:00
|
|
|
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
|
|
|
|
std::vector<double> realArgs;
|
|
|
|
for (auto v:tArgs)
|
|
|
|
realArgs.emplace_back(v);
|
|
|
|
|
2020-01-31 08:45:41 +01:00
|
|
|
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
2020-01-31 08:45:41 +01:00
|
|
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<nd4j::DataType> bArgs) {
|
|
|
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
2020-01-30 08:07:24 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
2019-06-06 14:21:15 +02:00
|
|
|
VariableSpace variableSpace;
|
|
|
|
//ResultSet arrayList;
|
|
|
|
FlowPath fp;
|
|
|
|
variableSpace.setFlowPath(&fp);
|
|
|
|
|
|
|
|
int cnt = -1;
|
|
|
|
std::vector<int> in;
|
|
|
|
for (auto v: inputs) {
|
|
|
|
if (v == nullptr)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
auto var = new Variable(v);
|
|
|
|
var->markRemovable(false);
|
|
|
|
in.push_back(cnt);
|
|
|
|
variableSpace.putVariable(cnt--, var);
|
|
|
|
}
|
|
|
|
|
|
|
|
Context block(1, &variableSpace, false);
|
2020-01-30 08:07:24 +01:00
|
|
|
block.setDataType(0, nd4j::DataType::FLOAT32);
|
2019-06-06 14:21:15 +02:00
|
|
|
block.fillInputs(in);
|
|
|
|
block.markInplace(isInplace);
|
2020-01-30 08:07:24 +01:00
|
|
|
// block.setRNG(ProviderRNG::getInstance().getRNG());
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
for (int e = 0; e < tArgs.size(); e++)
|
|
|
|
block.getTArguments()->emplace_back(tArgs.at(e));
|
|
|
|
|
|
|
|
for (int e = 0; e < iArgs.size(); e++)
|
|
|
|
block.getIArguments()->emplace_back(iArgs.at(e));
|
|
|
|
|
|
|
|
for (int e = 0; e < bArgs.size(); e++)
|
|
|
|
block.getBArguments()->push_back(bArgs.at(e));
|
|
|
|
|
2020-01-30 08:07:24 +01:00
|
|
|
for (int e = 0; e < dArgs.size(); e++)
|
|
|
|
block.getDArguments()->push_back(dArgs.at(e));
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jStatus status = this->execute(&block);
|
|
|
|
auto arrayList = new ResultSet();
|
|
|
|
if (isInplace)
|
|
|
|
arrayList->setNonRemovable();
|
|
|
|
|
|
|
|
arrayList->setStatus(status);
|
|
|
|
if (status != ND4J_STATUS_OK)
|
|
|
|
return arrayList;
|
|
|
|
|
|
|
|
|
|
|
|
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
|
|
|
std::pair<int,int> pair(1, e);
|
|
|
|
if (variableSpace.hasVariable(pair)) {
|
|
|
|
auto var = variableSpace.getVariable(pair);
|
|
|
|
auto arr = var->getNDArray();
|
|
|
|
if (!arr->isAttached()) {
|
|
|
|
var->markRemovable(false);
|
|
|
|
arr->setContext(nd4j::LaunchContext ::defaultContext());
|
|
|
|
arrayList->push_back(arr);
|
|
|
|
} else {
|
|
|
|
arrayList->push_back(arr->detach());
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
return arrayList;
|
|
|
|
}
|
|
|
|
|
|
|
|
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) {
|
2020-01-30 08:07:24 +01:00
|
|
|
// FIXME: add DArgs to OpArgsHolder
|
|
|
|
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<nd4j::DataType>(), isInplace);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) {
|
|
|
|
if (block.width() == 0)
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
|
|
|
|
NDArray *a0 = block.array(0);
|
|
|
|
for (int e = 0; e < block.width(); e++) {
|
|
|
|
auto aV = block.array(e);
|
|
|
|
if (!shape::equalsSoft(a0->getShapeInfo(), aV->getShapeInfo()))
|
|
|
|
return ND4J_STATUS_BAD_DIMENSIONS;
|
|
|
|
}
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
Nd4jStatus nd4j::ops::DeclarableOp::validateInputLengthMatch(Context& block) {
|
|
|
|
if (block.width() == 0)
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
|
|
|
|
|
|
|
|
Nd4jLong l0 = block.array(0)->lengthOf();
|
|
|
|
for (uint32_t e = 0; e < block.width(); e++) {
|
|
|
|
if (l0 != block.array(e)->lengthOf())
|
|
|
|
return ND4J_STATUS_BAD_LENGTH;
|
|
|
|
}
|
|
|
|
|
|
|
|
return ND4J_STATUS_OK;
|
|
|
|
}
|
|
|
|
|
2019-11-21 11:31:20 +01:00
|
|
|
samediff::EmptyHandling DeclarableOp::emptyHandling() {
|
|
|
|
return samediff::EmptyHandling::EMPTY_SKIP;
|
|
|
|
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
void DeclarableOp::registerTypes() {
|
|
|
|
this->getOpDescriptor()->setSameMode(true);
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
template <typename T>
|
|
|
|
int* nd4j::ops::DeclarableOp::calculateOutputShape(int* inputShape, nd4j::graph::Block& block) {
|
|
|
|
// default implementation suits transform, so just returns the same shape
|
|
|
|
|
|
|
|
int* newshape;
|
|
|
|
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape), int);
|
|
|
|
memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape));
|
|
|
|
|
|
|
|
return newshape;
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
}
|
|
|
|
}
|