// File: cavis/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp
// (source-viewer metadata: 929 lines, 41 KiB, C++; views: Raw / Normal / History)
// Blame revision: 2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 07.10.2017.
//
#include <ops/declarable/DeclarableOp.h>
#include <cstdarg>
#include <cstdio>
#include <algorithm>
#include <chrono>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
#include <Status.h>
#include <helpers/ShapeUtils.h>
#include <NDArrayFactory.h>
#include <exceptions/graph_exception.h>
#include <exceptions/unresolved_input_exception.h>
Platform helpers (#8216) * platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * 
F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * 
skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>
2019-09-11 20:50:28 +02:00
#include <ops/declarable/OpRegistrator.h>
2019-06-06 14:21:15 +02:00
namespace nd4j {
namespace ops {
/**
 * Checks a condition and, when it fails, prints a formatted diagnostic to stdout.
 *
 * @param file       source file reported in the header line
 * @param line       source line reported in the header line
 * @param condition  non-zero means "ok"
 * @param argNumber  argument index reported in the header line
 * @param format     printf-style format for the detail message, followed by its arguments
 * @return ND4J_STATUS_OK when condition holds, ND4J_STATUS_BAD_PARAMS otherwise
 */
Nd4jStatus conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...) {
    if (condition)
        return ND4J_STATUS_OK;

    printf("Error at [%s:%i:%i]:\n", file, line, argNumber);

    va_list detailArgs;
    va_start(detailArgs, format);
    vprintf(format, detailArgs);
    va_end(detailArgs);

    printf("\n");
    // flush so the message is visible even if the process dies right after
    fflush(stdout);
    return ND4J_STATUS_BAD_PARAMS;
}
/**
 * Default constructor. Performs no initialization of its own; no descriptor is created here.
 */
DeclarableOp::DeclarableOp() {
}
/**
 * Creates an op whose descriptor is built from a name and the "logical op" flag.
 */
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
    auto descriptor = new OpDescriptor(name, isLogical);
    this->_descriptor = descriptor;
}
/**
 * Creates an op from a name, an input count and the scalar flag.
 * Note that the OpDescriptor constructor takes the input count before the name.
 */
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
    auto descriptor = new OpDescriptor(numInputs, name, scalar);
    this->_descriptor = descriptor;
}
/**
 * Creates an op with fixed input/output counts and an inplace-capability flag.
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
    auto descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
    this->_descriptor = descriptor;
}
/**
 * Creates an op with fixed input/output counts, the inplace flag and the "divergent" flag,
 * all forwarded unchanged to the OpDescriptor.
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
    auto descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
    this->_descriptor = descriptor;
}
/**
 * Creates an op with fixed input/output counts, the inplace flag and the expected
 * number of T (floating point) and I (integer) arguments.
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
    auto descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
    this->_descriptor = descriptor;
}
/**
 * Releases the descriptor and the scalar owned by this op.
 */
DeclarableOp::~DeclarableOp() {
    // delete on a null pointer is a no-op, so no explicit null checks are needed
    delete _descriptor;
    delete _scalar;
}
/**
 * @return the descriptor owned by this op (not a copy)
 */
OpDescriptor* DeclarableOp::getOpDescriptor() {
    OpDescriptor *descriptor = this->_descriptor;
    return descriptor;
}
std::string *DeclarableOp::getOpName() {
return _descriptor->getOpName();
}
/**
 * @return the op hash as stored in the descriptor
 */
Nd4jLong DeclarableOp::getOpHash() {
    auto descriptor = this->_descriptor;
    return descriptor->getHash();
}
/**
 * Resolves the output array (Z) with index `inputId` for this op invocation.
 *
 * Fast path: returns fastpath_out()[inputId] when it exists. For inplace execution it
 * falls back to the matching fastpath_in() array.
 * Graph path: for inplace execution, the input array at `inputId` is re-published in the
 * VariableSpace under (nodeId, inputId). The Variable is created if missing and marked
 * non-removable. Otherwise, the array already stored under (nodeId, inputId) is returned.
 *
 * @param ctx     execution context
 * @param inputId output index to resolve
 * @return the resolved array, or nullptr when the graph-path variable holds no usable array
 * @throws std::runtime_error if a fast-path output (or inplace input) cannot be resolved
 */
nd4j::NDArray* nd4j::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
    if (ctx.isFastPath()) {
        const auto& outputs = ctx.fastpath_out();
        if (inputId >= 0 && static_cast<size_t>(inputId) < outputs.size())
            return outputs[inputId];

        if (!ctx.isInplace())
            throw std::runtime_error("fastpath_out: unresolved output array");

        // inplace: the output is the input with the same index - make sure it exists
        const auto& inputs = ctx.fastpath_in();
        if (inputId < 0 || static_cast<size_t>(inputId) >= inputs.size())
            throw std::runtime_error("fastpath_in: unresolved inplace array");
        return inputs[inputId];
    }

    std::pair<int, int> pair(ctx.nodeId(), inputId);

    if (ctx.isInplace()) {
        NDArray *z = ctx.variable(inputId)->getNDArray();
        auto varSpace = ctx.getVariableSpace();

        // hypothetically it's possible to have no variable. chances are low, but let's just create it
        if (!varSpace->hasVariable(pair))
            varSpace->putVariable(pair, new Variable());

        // now we're saving input array as output array; it must not be released as an output
        auto var = varSpace->getVariable(pair);
        var->markRemovable(false);
        var->setNDArray(z);
        return z;
    }

    auto var = ctx.variable(pair);
    if (var->getNDArray() != nullptr && var->getNDArray()->nonNull())
        return var->getNDArray();

    nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId());
    return nullptr;
}
/**
 * Pre-allocates (or validates) the output arrays of a non-inplace op.
 *
 * Collects the input shapes and calls calculateOutputShape(). Then, for every resulting
 * shape, it either creates a new NDArray (no output present yet) or checks that the
 * existing output array is soft-equal in shape. When profiling is enabled and a node
 * profile is available, input/shape/array timings are recorded.
 *
 * @return ctx.width() for inplace ops (nothing is allocated), otherwise the number of calculated output shapes
 * @throws unresolved_input_exception if an NDARRAY input variable holds no array
 * @throws std::runtime_error on an output shape mismatch, or if an available output value is not an NDArray
 */
int nd4j::ops::DeclarableOp::prepareOutputs(Context &ctx) {
    NodeProfile *node = nullptr;
    if (Environment::getInstance()->isProfiling()) {
        auto varSpace = ctx.getVariableSpace();
        if (varSpace != nullptr && varSpace->flowPath() != nullptr) {
            GraphProfile *prof = varSpace->flowPath()->profile();
            // profile() may be null; execute() guards against that as well
            if (prof != nullptr)
                node = prof->nodeById(ctx.nodeId());
        }
    }

    // inplace ops write into their inputs; getZ() resolves those, so there is nothing to allocate
    if (ctx.isInplace())
        return static_cast<int>(ctx.width());

    using Clock = std::chrono::system_clock;
    const bool profiling = node != nullptr && Environment::getInstance()->isProfiling();
    Clock::time_point inputStart, shapeStart, arrayStart;

    if (profiling)
        inputStart = Clock::now();

    // build the list of input shapes
    ShapeList inSha;
    if (ctx.isFastPath()) {
        for (auto p : ctx.fastpath_in())
            inSha.push_back(p->getShapeInfo());
    } else {
        for (auto p : *ctx.inputs()) {
            auto var = ctx.variable(p);
            if (var->variableType() == VariableType::NDARRAY) {
                NDArray *array = var->getNDArray();
                if (array == nullptr)
                    throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
                inSha.push_back(array->getShapeInfo());
            }
        }
    }

    if (profiling) {
        auto inputEnd = Clock::now();
        node->setInputTime(std::chrono::duration_cast<std::chrono::nanoseconds>(inputEnd - inputStart).count());
        shapeStart = Clock::now();
    }

    // owned here and released on every exit path, including the exceptions thrown below
    std::unique_ptr<ShapeList> outSha(this->calculateOutputShape(&inSha, ctx));
    const int results = static_cast<int>(outSha->size());

    if (profiling) {
        auto shapeEnd = Clock::now();
        node->setShapeFunctionTime(std::chrono::duration_cast<std::chrono::nanoseconds>(shapeEnd - shapeStart).count());
        arrayStart = Clock::now();
    }

    int cnt = 0;
    for (auto out : *outSha->asVector()) {
        const int idx = cnt++;

        if (!ctx.isFastPath()) {
            // we need to check if Z is really needed
            std::pair<int, int> pair(ctx.nodeId(), idx);
            if (!ctx.isValueAvailable(idx)) {
                if (Environment::getInstance()->isDebugAndVerbose())
                    shape::printShapeInfoLinear("Going to create variable with shape", out);

                auto outArr = new NDArray(out, true, ctx.launchContext());
                ctx.pushNDArrayToVariableSpace(pair, outArr);
            } else {
                // validate/compare shapes here: existing vs calculated
                auto existing = ctx.variable(pair)->getNDArray();
                if (existing == nullptr)
                    throw std::runtime_error("Output variable doesn't hold an NDArray");

                auto shape = existing->shapeInfo();
                if (!shape::equalsSoft(out, shape)) {
                    auto eShape = ShapeUtils::shapeAsString(out);
                    auto aShape = ShapeUtils::shapeAsString(shape);
                    nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
                    throw std::runtime_error("Expected vs provided shapes mismatch");
                }
            }
        } else {
            const auto& fout = ctx.fastpath_out();
            if (fout.size() <= static_cast<size_t>(idx)) {
                // array doesn't exist yet - allocate it and let the context own it
                auto outArr = new NDArray(out, true, ctx.launchContext());
                ctx.setOutputArray(idx, outArr, true);
            } else {
                auto array = fout[idx];
                if (!shape::equalsSoft(out, array->shapeInfo())) {
                    auto eShape = ShapeUtils::shapeAsString(out);
                    auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
                    nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
                    throw std::runtime_error("Expected vs provided shape mismatch");
                }
            }
        }
    }

    // release the shape list before taking the array timestamp, as before
    outSha.reset();

    if (profiling) {
        auto arrayEnd = Clock::now();
        node->setArrayTime(std::chrono::duration_cast<std::chrono::nanoseconds>(arrayEnd - arrayStart).count());
    }

    return results;
}
/**
 * Pointer convenience overload: dereferences `array` and forwards to the reference overload.
 * NOTE(review): `array` is dereferenced unchecked, so callers must pass a non-null pointer.
 */
void nd4j::ops::DeclarableOp::storeResult(Context &block, int outputNumber, NDArray* array) {
    NDArray &result = *array;
    this->storeResult(block, outputNumber, result);
}
/**
 * Publishes `array` as output `outputNumber` of the current node in the context's VariableSpace.
 * The last argument passed is !ctx.isInplace(), so it is false for inplace execution.
 */
void nd4j::ops::DeclarableOp::storeResult(nd4j::graph::Context &ctx, int outputNumber, NDArray& array) {
    const bool notInplace = !ctx.isInplace();
    ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, &array, notInplace);
}
/**
 * Ensures output 0 of the current node holds an array described by `shape`.
 *
 * The array is created on the first run (no array yet). It is recreated when the held
 * array's length differs from the requested one; the previous array is deleted.
 * Otherwise the existing array is kept.
 *
 * @param block execution context
 * @param shape shapeInfo describing the required output (rank, dims, strides, dtype)
 * @return always true
 */
bool nd4j::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) {
    auto var = block.variable(block.getNodeId(), 0);
    auto workspace = block.getWorkspace();
    const Nd4jLong len = shape::length(shape);
    const auto dtype = ArrayOptions::dataType(shape);

    auto current = var->getNDArray();
    if (current == nullptr || current->lengthOf() != len) {
        // no-op on the first run; otherwise the length doesn't match, so drop the old array
        delete current;

        // DataBuffer takes a size in BYTES, so scale the element count by the element width
        // (previously sized as len * sizeof(int8_t), which under-allocated wider dtypes).
        // ShapeDescriptor is built straight from `shape`, so no temporary shapeInfo copy
        // (which was never freed) is needed.
        auto buffer = std::make_shared<DataBuffer>(len * DataTypeUtils::sizeOf(dtype), dtype, workspace);
        var->setNDArray(new NDArray(buffer, ShapeDescriptor(shape), block.launchContext()));
    }
    return true;
}
/**
 * Ensures output 0 of the current node holds an array with the given shape and order,
 * using the context's data type.
 *
 * An existing array whose length already matches is kept. Otherwise any existing array
 * is deleted and a new one is created.
 *
 * @return always true
 */
bool nd4j::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list<Nd4jLong>& shape, char order) {
    auto var = block.variable(block.getNodeId(), 0);
    auto current = var->getNDArray();

    // nothing to do when an array of the right length is already there
    if (current != nullptr && current->lengthOf() == shape::length(shape))
        return true;

    // first run (current == nullptr, delete is a no-op) or length mismatch: (re)allocate
    delete current;
    var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext()));
    return true;
}
/**
 * Validates the data types of this op's inputs and (optionally present) outputs against
 * the descriptor's type rules.
 *
 * On first use, it lazily calls registerTypes() under the `_registrator` lock.
 * Inputs: each non-null input array must pass checkInputMatch().
 * Outputs:
 *  - same mode: each output must match the input with the same index, or input 0 for
 *    extra outputs beyond width();
 *  - inherit mode: the output dtype must equal one of the validated input dtypes;
 *  - otherwise: checkOutputMatch() must accept the output dtype.
 *
 * @return ND4J_STATUS_OK, or ND4J_STATUS_BAD_ARGUMENTS on the first mismatch (which is also printed)
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateDataTypes(Context& block) {
    {
        // scoped lock: the mutex is released even if registerTypes() throws
        std::lock_guard<decltype(_registrator)> lock(_registrator);
        if (!_registered) {
            _registered = true;
            this->registerTypes();
        }
    }

    // dtypes of all validated inputs, used by the "inherit" output rule.
    // Filled with push_back so skipped inputs don't leave default-valued entries
    // and a fast-path input count above width() can't write out of bounds.
    std::vector<nd4j::DataType> inputTypes;
    inputTypes.reserve(block.width());

    // records an input's dtype and checks it against the descriptor; false (after printing) on mismatch
    auto acceptInput = [&](int inputIdx, NDArray *array) -> bool {
        const auto dtype = array->dataType();
        inputTypes.push_back(dtype);
        if (_descriptor->checkInputMatch(inputIdx, dtype))
            return true;

        auto ctype = DataTypeUtils::asString(dtype);
        nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
                    _descriptor->getOpName()->data(), inputIdx, ctype.c_str());
        return false;
    };

    // rolling over inputs first
    int cnt = 0;
    if (block.isFastPath()) {
        for (auto array : block.fastpath_in()) {
            if (!acceptInput(cnt, array))
                return ND4J_STATUS_BAD_ARGUMENTS;
            cnt++;
        }
    } else {
        for (auto &p : *(block.inputs())) {
            auto var = block.variable(p);
            // only validating non-null variables
            if (var != nullptr && var->hasNDArray() && !acceptInput(cnt, var->getNDArray()))
                return ND4J_STATUS_BAD_ARGUMENTS;
            cnt++;
        }
    }

    // same mode: output type must be the same as the corresponding input type
    auto sameModeMismatch = [&](int index, nd4j::DataType cType) -> Nd4jStatus {
        auto t = DataTypeUtils::asString(cType);
        nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                    _descriptor->getOpName()->data(), index, t.c_str());
        return ND4J_STATUS_BAD_ARGUMENTS;
    };

    // non-same mode: inherit (any input dtype) or the explicit output dtype list
    auto checkOtherModes = [&](int index, nd4j::DataType cType) -> Nd4jStatus {
        if (_descriptor->isInherit(index)) {
            if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
                auto t = DataTypeUtils::asString(cType);
                nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
                            _descriptor->getOpName()->data(), index, t.c_str());
                return ND4J_STATUS_BAD_ARGUMENTS;
            }
        } else if (!_descriptor->checkOutputMatch(index, cType)) {
            auto t = DataTypeUtils::asString(cType);
            nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
                        _descriptor->getOpName()->data(), index, t.c_str());
            return ND4J_STATUS_BAD_ARGUMENTS;
        }
        return ND4J_STATUS_OK;
    };

    if (block.isFastPath()) {
        const auto& outputs = block.fastpath_out();
        const auto& inputs = block.fastpath_in();

        // index-based loop so that `continue` can't desynchronize the output index
        for (size_t e = 0; e < outputs.size(); e++) {
            const int index = static_cast<int>(e);
            const auto cType = outputs[e]->dataType();

            if (_descriptor->isSameMode()) {
                if (index >= block.width()) {
                    // extra outputs are compared against the first input, if any
                    if (inputs.empty())
                        continue;
                    if (inputs[0]->dataType() != cType)
                        return sameModeMismatch(index, cType);
                } else if (inputs[index]->dataType() != cType) {
                    return sameModeMismatch(index, cType);
                }
            } else {
                const auto status = checkOtherModes(index, cType);
                if (status != ND4J_STATUS_OK)
                    return status;
            }
        }
    } else {
        // checking optionally available outputs, until the first missing index
        auto varSpace = block.getVariableSpace();
        for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
            if (varSpace == nullptr || !varSpace->hasVariable(block.nodeId(), index))
                break;

            auto var = block.variable(block.nodeId(), index);
            // only validating non-null variables
            if (var == nullptr || !var->hasNDArray())
                continue;

            const auto cType = var->getNDArray()->dataType();
            if (_descriptor->isSameMode()) {
                if (index >= block.width()) {
                    if (block.width() == 0)
                        continue;
                    if (block.variable(0)->getNDArray()->dataType() != cType)
                        return sameModeMismatch(index, cType);
                } else if (block.variable(index)->getNDArray()->dataType() != cType) {
                    return sameModeMismatch(index, cType);
                }
            } else {
                const auto status = checkOtherModes(index, cType);
                if (status != ND4J_STATUS_OK)
                    return status;
            }
        }
    }

    return ND4J_STATUS_OK;
}
Nd4jStatus nd4j::ops::DeclarableOp::execute(Context* block) {
nd4j_debug("Executing op: [%s]\n", this->getOpName()->c_str());
std::chrono::time_point<std::chrono::system_clock> timeEnter, timeStart, timeEnd;
Nd4jLong prepTime, outerTime;
Nd4jLong memoryBefore = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
if (Environment::getInstance()->isProfiling())
timeEnter = std::chrono::system_clock::now();
// basic validation: ensure inputs are set
REQUIRE_OK(this->validateNonEmptyInput(*block));
// ensure number of IArgs, TArgs match our expectations
REQUIRE_OK(this->validateArguments(*block));
// validating data types for inputs and (optionally) outputs
REQUIRE_OK(this->validateDataTypes(*block));
// this method will allocate output NDArrays for this op
auto numOutputs = this->prepareOutputs(*block);
if (Environment::getInstance()->isProfiling()) {
timeStart = std::chrono::system_clock::now();
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
}
Platform helpers (#8216) * platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * 
F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * 
skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>
2019-09-11 20:50:28 +02:00
Nd4jStatus status;
bool hasHelper = false;
// if we have platform-specific helper for this op - invoke it
if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) {
auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash());
if (helper->isUsable(*block)) {
status = helper->invokeHelper(*block);
hasHelper = true;
}
}
// if we don't have platform-specific helper - invoke generic implementation
if (!hasHelper)
status = this->validateAndExecute(*block);
2019-06-06 14:21:15 +02:00
// optionally saving execution time
if (Environment::getInstance()->isProfiling()) {
timeEnd = std::chrono::system_clock::now();
outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
block->setInnerTime(outerTime);
}
if (Environment::getInstance()->isProfiling()) {
auto fp = block->getVariableSpace()->flowPath();
if (fp != nullptr) {
auto p = fp->profile();
if (p != nullptr) {
Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
Nd4jLong memoryUsed = memoryAfter - memoryBefore;
p->nodeById(block->nodeId())->setPreparationTime(prepTime);
p->nodeById(block->nodeId())->setExecutionTime(outerTime);
p->nodeById(block->nodeId())->setTotalSize(memoryUsed);
}
}
}
// now we print out all outputs for this node
if (nd4j::Environment::getInstance()->isDebugAndVerbose()) {
auto vs = block->getVariableSpace();
for (int e = 0; e < numOutputs; e++) {
// if given output index doesn't exist - we're done
if (!block->isFastPath()) {
if (!vs->hasVariable(block->nodeId(), e))
break;
} else {
// we have to check either in or out stack, depending on isInplace()
if (block->isInplace()) {
if (block->fastpath_in().size() <= e)
break;
} else {
if (block->fastpath_out().size() <= e)
break;
}
}
auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray();
auto shape = ShapeUtils::shapeAsString(array);
auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32);
auto type = DataTypeUtils::asString(array->dataType());
nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", block->nodeId(), e, shape.c_str(), type.c_str(), first.c_str());
}
}
return status;
}
/**
 * Disabled entry point: any call is treated as an error.
 *
 * An earlier implementation, which replaced the NDArray stored in the VariableSpace under
 * (nodeId, outputIdx), is no longer active.
 *
 * @throws std::runtime_error always
 */
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArray *array) {
    (void) block;
    (void) outputIdx;
    (void) array;
    throw std::runtime_error("Overwrite result used!");
}
/**
 * Disabled entry point (NDArrayList variant): any call is treated as an error.
 *
 * An earlier implementation, which attached `list` to the Variable under
 * (nodeId, outputIdx), is no longer active.
 *
 * @throws std::runtime_error always
 */
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArrayList *list) {
    (void) block;
    (void) outputIdx;
    (void) list;
    throw std::runtime_error("Overwrite result used!");
}
/**
 * Checks the number of T (floating point) and I (integer) arguments supplied in `block`
 * against the descriptor:
 *  - a positive expected count is a lower bound: at least that many must be supplied;
 *  - an expected count of -1 means "variable, but at least one must be supplied";
 *  - any other expected value (e.g. 0) disables the check.
 *
 * @return ND4J_STATUS_OK, or ND4J_STATUS_BAD_PARAMS (after printing the reason)
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateArguments(Context& block) {
    // sizes are converted to int explicitly: passing size_t to a %i conversion is undefined behavior
    const int expectedT = _descriptor->getNumberOfTArgs();
    const int actualT = static_cast<int>(block.getTArguments()->size());

    if (expectedT > 0 && actualT < expectedT) {
        nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName()->c_str(), expectedT, actualT);
        return ND4J_STATUS_BAD_PARAMS;
    }
    if (expectedT == -1 && actualT == 0) {
        nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
        return ND4J_STATUS_BAD_PARAMS;
    }

    const int expectedI = _descriptor->getNumberOfIArgs();
    const int actualI = static_cast<int>(block.getIArguments()->size());

    if (expectedI > 0 && actualI < expectedI) {
        nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName()->c_str(), expectedI, actualI);
        return ND4J_STATUS_BAD_PARAMS;
    }
    if (expectedI == -1 && actualI == 0) {
        nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
        return ND4J_STATUS_BAD_PARAMS;
    }

    return ND4J_STATUS_OK;
}
/**
 * Checks that every input array has exactly `rank` dimensions.
 *
 * @return ND4J_STATUS_OK (also when there are no inputs),
 *         ND4J_STATUS_BAD_INPUT if an input variable or its array is missing,
 *         ND4J_STATUS_BAD_DIMENSIONS on the first rank mismatch
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensions(Context& block, int rank) {
    if (block.width() == 0)
        return ND4J_STATUS_OK;

    for (auto p: *block.inputs()) {
        auto v = block.variable(p);
        // a missing variable is reported like a missing array instead of being dereferenced
        NDArray *aV = v == nullptr ? nullptr : v->getNDArray();

        if (aV == nullptr)
            return ND4J_STATUS_BAD_INPUT;

        if (aV->rankOf() != rank)
            return ND4J_STATUS_BAD_DIMENSIONS;
    }

    return ND4J_STATUS_OK;
}
/**
 * Checks that all inputs are rank-2 (matrices); see validateInputDimensions().
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInput2D(Context& block) {
    const int expectedRank = 2;
    return this->validateInputDimensions(block, expectedRank);
}
/**
 * Checks that all inputs are rank-3; see validateInputDimensions().
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInput3D(Context& block) {
    const int expectedRank = 3;
    return this->validateInputDimensions(block, expectedRank);
}
/**
 * Checks that all inputs are rank-4; see validateInputDimensions().
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInput4D(Context& block) {
    const int expectedRank = 4;
    return this->validateInputDimensions(block, expectedRank);
}
/**
 * Checks that every declared input is present and, for NDARRAY variables, holds a usable
 * array. Arrays that are intentionally empty are accepted.
 *
 * Ops declaring 0 inputs, or -2 (presumably "no fixed input requirement" - TODO confirm
 * against OpDescriptor), skip the check entirely.
 *
 * @return ND4J_STATUS_OK / Status::OK(), or ND4J_STATUS_BAD_INPUT (after printing the offending input)
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateNonEmptyInput(Context& block) {
    const auto numInputs = this->getOpDescriptor()->getNumberOfInputs();
    if (numInputs == -2 || numInputs == 0)
        return Status::OK();

    if (block.width() < 1) {
        nd4j_printf("%s: no operands provided for the op\n", this->getOpName()->c_str());
        return ND4J_STATUS_BAD_INPUT;
    }

    int cnt = 0;
    for (auto p: *block.inputs()) {
        // position of this input; advanced for every input (including skipped empty ones)
        // so diagnostics always report the correct index
        const int idx = cnt++;

        auto v = block.variable(p);
        if (v == nullptr) {
            if (this->getOpName() != nullptr) {
                nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), idx, p.first, p.second);
            } else {
                nd4j_printf("Node [%i:<noname>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), idx, p.first, p.second);
            }
            return ND4J_STATUS_BAD_INPUT;
        }

        if (v->variableType() != VariableType::NDARRAY)
            continue;

        // if array is empty intentionally - we're ok with that
        if (v->hasNDArray() && v->isEmpty())
            continue;

        NDArray *aV = v->getNDArray();
        if (aV == nullptr || !aV->nonNull()) {
            if (this->getOpName() != nullptr) {
                nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), idx, p.first, p.second);
            } else {
                nd4j_printf("Node [%i:<noname>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), idx, p.first, p.second);
            }
            return ND4J_STATUS_BAD_INPUT;
        }
    }

    return ND4J_STATUS_OK;
}
/**
 * Checks that every input array of `block` has the same ordering() as the first input.
 *
 * @return ND4J_STATUS_OK when there are no inputs or all orderings match;
 *         ND4J_STATUS_BAD_INPUT when an input variable or its array is missing;
 *         ND4J_STATUS_BAD_ORDER on the first input whose ordering differs.
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateOrdersMatch(Context& block) {
    if (block.width() == 0)
        return ND4J_STATUS_OK;

    auto v0 = block.variable(0);
    NDArray *a0 = v0 == nullptr ? nullptr : v0->getNDArray();
    if (a0 == nullptr)
        return ND4J_STATUS_BAD_INPUT;

    for (const auto &p: *block.inputs()) {
        auto v = block.variable(p);
        // a missing variable/array is reported rather than dereferenced
        NDArray *aV = v == nullptr ? nullptr : v->getNDArray();
        if (aV == nullptr)
            return ND4J_STATUS_BAD_INPUT;

        if (a0->ordering() != aV->ordering())
            return ND4J_STATUS_BAD_ORDER;
    }

    return ND4J_STATUS_OK;
}
/**
 * Convenience overload: copies each initializer list into a std::vector and forwards
 * to the vector-based execute() that returns a ResultSet.
 */
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
    std::vector<NDArray*> inputArrays(inputs.begin(), inputs.end());
    std::vector<double> floatArgs(tArgs.begin(), tArgs.end());
    std::vector<Nd4jLong> intArgs(iArgs.begin(), iArgs.end());
    std::vector<bool> boolArgs(bArgs.begin(), bArgs.end());

    return execute(inputArrays, floatArgs, intArgs, boolArgs, isInplace, type);
}
/**
 * Convenience overload with caller-supplied output arrays: copies each initializer list
 * into a std::vector and forwards to the vector-based execute() returning a status.
 */
Nd4jStatus nd4j::ops::DeclarableOp::execute(std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
    std::vector<NDArray*> inputArrays(inputs.begin(), inputs.end());
    std::vector<NDArray*> outputArrays(outputs.begin(), outputs.end());
    std::vector<double> floatArgs(tArgs.begin(), tArgs.end());
    std::vector<Nd4jLong> intArgs(iArgs.begin(), iArgs.end());
    std::vector<bool> boolArgs(bArgs.begin(), bArgs.end());

    return execute(inputArrays, outputArrays, floatArgs, intArgs, boolArgs, isInplace, type);
}
/**
 * Convenience overload with an explicit random generator and caller-supplied outputs:
 * copies each initializer list into a std::vector and forwards to the vector-based
 * execute(rng, ...) overload.
 */
Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<NDArray*> outputs , std::initializer_list<double> tArgs, std::initializer_list<Nd4jLong> iArgs, std::initializer_list<bool> bArgs, bool isInplace, nd4j::DataType type) {
    std::vector<NDArray*> inputArrays(inputs.begin(), inputs.end());
    std::vector<NDArray*> outputArrays(outputs.begin(), outputs.end());
    std::vector<double> floatArgs(tArgs.begin(), tArgs.end());
    std::vector<Nd4jLong> intArgs(iArgs.begin(), iArgs.end());
    std::vector<bool> boolArgs(bArgs.begin(), bArgs.end());

    return execute(rng, inputArrays, outputArrays, floatArgs, intArgs, boolArgs, isInplace, type);
}
/**
 * Executes the op on caller-supplied inputs/outputs without an explicit random generator:
 * a RandomGenerator constructed with (0, 0) is used and the call is forwarded to
 * execute(rng, ...).
 */
Nd4jStatus nd4j::ops::DeclarableOp::execute(std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
    nd4j::graph::RandomGenerator defaultRng(0, 0);
    return this->execute(defaultRng, inputs, outputs, tArgs, iArgs, bArgs, isInplace, type);
}
/**
 * Executes the op on caller-provided input and output arrays with an explicit random generator.
 *
 * A temporary VariableSpace is built:
 * - non-null inputs are registered as variables -1, -2, ... (null entries are skipped);
 * - every entry of `outputs` is registered as variable (1, 0), (1, 1), ...
 * The op then runs through a Context with id 1 over that space.
 *
 * @param rng       random generator handed to the context
 * @param inputs    input arrays
 * @param outputs   output arrays, registered in order
 * @param tArgs     floating-point arguments
 * @param iArgs     integer arguments, passed through at full Nd4jLong width
 * @param bArgs     boolean arguments
 * @param isInplace forwarded to Context::markInplace
 * @param type      data type stored via setDataType(0, type)
 * @return the status returned by execute(Context*)
 */
Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<NDArray*>& outputs, std::vector<double>& tArgs, std::vector<Nd4jLong>& iArgs, std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
    VariableSpace variableSpace;
    FlowPath fp;
    variableSpace.setFlowPath(&fp);

    int cnt = -1;
    std::vector<int> in;
    for (auto v: inputs) {
        if (v == nullptr)
            continue;

        auto var = new Variable(v);
        // NOTE(review): presumably keeps the caller-owned array alive when the Variable is released
        var->markRemovable(false);
        in.push_back(cnt);
        variableSpace.putVariable(cnt--, var);
    }

    int et = 0;
    for (auto v: outputs) {
        auto var = new Variable(v);
        var->markRemovable(false);
        std::pair<int,int> pair(1, et++);
        variableSpace.putVariable(pair, var);
    }

    Context block(1, &variableSpace, false);
    block.fillInputs(in);
    block.markInplace(isInplace);
    block.setDataType(0, type);

    // rng is a reference, so it is always valid and is set unconditionally
    block.setRng(rng);

    for (double t: tArgs)
        block.getTArguments()->emplace_back(t);

    // integer args are stored without narrowing to int, matching the
    // ResultSet-returning execute() overload (previously they were truncated here)
    for (Nd4jLong i: iArgs)
        block.getIArguments()->emplace_back(i);

    for (bool b: bArgs)
        block.getBArguments()->push_back(b);

    return this->execute(&block);
}
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
VariableSpace variableSpace;
//ResultSet arrayList;
FlowPath fp;
variableSpace.setFlowPath(&fp);
int cnt = -1;
std::vector<int> in;
for (auto v: inputs) {
if (v == nullptr)
continue;
auto var = new Variable(v);
var->markRemovable(false);
in.push_back(cnt);
variableSpace.putVariable(cnt--, var);
}
Context block(1, &variableSpace, false);
block.setDataType(0, type);
block.fillInputs(in);
block.markInplace(isInplace);
[WIP] More of CUDA (#95) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * Implementation of hashcode cuda helper. Working edition. * Fixed parallel test input arangements. * Fixed tests for hashcode op. * Fixed shape calculation for image:crop_and_resize op and test. * NativeOps tests. Initial test suite. * Added tests for indexReduce methods. * Added test on execBroadcast with NDArray as dimensions. * Added test on execBroadcastBool with NDArray as dimensions. * Added tests on execPairwiseTransform and execPairwiseTransofrmBool. * Added tests for execReduce with scalar results. * Added reduce tests for non-empty dims array. * Added tests for reduce3. * Added tests for execScalar. * Added tests for execSummaryStats. * - provide cpu/cuda code for batch_to_space - testing it Signed-off-by: Yurii <yurii@skymind.io> * - remove old test for batch_to_space (had wrong format and numbers were not checked) Signed-off-by: Yurii <yurii@skymind.io> * Fixed complilation errors with test. * Added test for execTransformFloat. * Added test for execTransformSame. * Added test for execTransformBool. * Added test for execTransformStrict. * Added tests for execScalar/execScalarBool with TADs. * Added test for flatten. * - provide cpu/cuda code for space_to_Batch operaion Signed-off-by: Yurii <yurii@skymind.io> * Added test for concat. * comment unnecessary stuff in s_t_b Signed-off-by: Yurii <yurii@skymind.io> * Added test for specialConcat. * Added tests for memcpy/set routines. * Fixed pullRow cuda test. * Added pullRow test. * Added average test. * - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...) Signed-off-by: Yurii <yurii@skymind.io> * - debugging and fixing cuda tests in JavaInteropTests file Signed-off-by: Yurii <yurii@skymind.io> * - correct some tests Signed-off-by: Yurii <yurii@skymind.io> * Added test for shuffle. * Fixed ops declarations. * Restored omp and added shuffle test. * Added convertTypes test. 
* Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps. * Added sort tests. * Added tests for execCustomOp. * - further debuging and fixing tests terminated with crash Signed-off-by: Yurii <yurii@skymind.io> * Added tests for calculateOutputShapes. * Addded Benchmarks test. * Commented benchmark tests. * change assertion Signed-off-by: raver119 <raver119@gmail.com> * Added tests for apply_sgd op. Added cpu helper for that op. * Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps. * Added test for assign broadcastable. * Added tests for assign_bp op. * Added tests for axpy op. * - assign/execScalar/execTransformAny signature change - minor test fix Signed-off-by: raver119 <raver119@gmail.com> * Fixed axpy op. * meh Signed-off-by: raver119 <raver119@gmail.com> * - fix tests for nativeOps::concat Signed-off-by: Yurii <yurii@skymind.io> * sequential transform/scalar Signed-off-by: raver119 <raver119@gmail.com> * allow nested parallelism Signed-off-by: raver119 <raver119@gmail.com> * assign_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * block setRNG fix Signed-off-by: raver119 <raver119@gmail.com> * enable parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * enable nested parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * Added cuda implementation for row_count helper. * Added implementation for tnse gains op helper. * - take into account possible situations when input arrays are empty in reduce_ cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces. * Added kernel for tsne/symmetrized op heleper. * Implementation of tsne/symmetrized op cuda helper. Working edition. * Eliminated waste printfs. * Added test for broadcastgradientargs op. 
* host-only fallback for empty reduce float Signed-off-by: raver119 <raver119@gmail.com> * - some tests fixes Signed-off-by: Yurii <yurii@skymind.io> * - correct the rest of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * - further correction of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * Added test for Cbow op. Also added cuda implementation for cbow helpers. * - improve code of stack operation for scalar case Signed-off-by: Yurii <yurii@skymind.io> * - provide cuda kernel for gatherND operation Signed-off-by: Yurii <yurii@skymind.io> * Implementation of cbow helpers with cuda kernels. * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * - further correction of cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implementatation of cbow op helper with cuda kernels. Working edition. * Skip random testing for cudablas case. * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for ELU and ELU_BP ops. * Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops. * Added tests for neq_scalar. * Added test for noop. * - further work on clipbynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * - get rid of concat op call, use instead direct concat helper call Signed-off-by: Yurii <yurii@skymind.io> * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for lrelu and lrelu_bp. * Added tests for selu and selu_bp. * Fixed lrelu derivative helpers. 
* - some corrections in lstm Signed-off-by: Yurii <yurii@skymind.io> * operator * result shape fix Signed-off-by: raver119 <raver119@gmail.com> * - correct typo in lstmCell Signed-off-by: Yurii <yurii@skymind.io> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * CUDA inverse broadcast bool fix Signed-off-by: raver119 <raver119@gmail.com> * disable MMAP test for CUDA Signed-off-by: raver119 <raver119@gmail.com> * BooleanOp syncToDevice Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * additional data types for im2col/col2im Signed-off-by: raver119 <raver119@gmail.com> * Added test for firas_sparse op. * one more RandomBuffer test excluded Signed-off-by: raver119 <raver119@gmail.com> * Added tests for flatten op. * Added test for Floor op. * bunch of tests fixed Signed-off-by: raver119 <raver119@gmail.com> * mmulDot tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Implemented floordiv_bp op and tests. * Fixed scalar case with cuda implementation for bds. * - work on cuda kernel for clip_by_norm backprop op is completed Signed-off-by: Yurii <yurii@skymind.io> * Eliminate cbow crach. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Eliminated abortion with batched nlp test. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Fixed shared flag initializing. * disabled bunch of cpu workspaces tests Signed-off-by: raver119 <raver119@gmail.com> * scalar operators fix: missing registerSpecialUse call Signed-off-by: raver119 <raver119@gmail.com> * Fixed logdet for cuda and tests. * - correct clipBynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * Fixed crop_and_resize shape datatype. * - correct some mmul tests Signed-off-by: Yurii <yurii@skymind.io>
2019-08-02 19:01:03 +02:00
// block.setRNG(ProviderRNG::getInstance().getRNG());
2019-06-06 14:21:15 +02:00
for (int e = 0; e < tArgs.size(); e++)
block.getTArguments()->emplace_back(tArgs.at(e));
for (int e = 0; e < iArgs.size(); e++)
block.getIArguments()->emplace_back(iArgs.at(e));
for (int e = 0; e < bArgs.size(); e++)
block.getBArguments()->push_back(bArgs.at(e));
Nd4jStatus status = this->execute(&block);
auto arrayList = new ResultSet();
if (isInplace)
arrayList->setNonRemovable();
arrayList->setStatus(status);
if (status != ND4J_STATUS_OK)
return arrayList;
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e);
if (variableSpace.hasVariable(pair)) {
auto var = variableSpace.getVariable(pair);
auto arr = var->getNDArray();
if (!arr->isAttached()) {
var->markRemovable(false);
arr->setContext(nd4j::LaunchContext ::defaultContext());
arrayList->push_back(arr);
} else {
arrayList->push_back(arr->detach());
}
} else
break;
}
return arrayList;
}
/**
 * Runs the op with input arrays and T/I/B arguments taken from an OpArgsHolder.
 * This overload always passes nd4j::DataType::DOUBLE as the data type argument.
 */
nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) {
    const auto &inArrs = holder.getInArrs();
    const auto &tArgs  = holder.getTArgs();
    const auto &iArgs  = holder.getIArgs();
    const auto &bArgs  = holder.getBArgs();

    return this->execute(inArrs, tArgs, iArgs, bArgs, isInplace, nd4j::DataType::DOUBLE);
}
/**
 * Checks that every input array of `block` has the same shape as the first input,
 * as judged by shape::equalsSoft.
 *
 * @return ND4J_STATUS_OK when there are no inputs or all shapes match;
 *         ND4J_STATUS_BAD_INPUT when an input array is missing;
 *         ND4J_STATUS_BAD_DIMENSIONS on the first mismatching input.
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) {
    if (block.width() == 0)
        return ND4J_STATUS_OK;

    NDArray *a0 = block.array(0);
    if (a0 == nullptr)
        return ND4J_STATUS_BAD_INPUT;

    // input 0 is the reference shape, so comparison starts at index 1
    const int numInputs = static_cast<int>(block.width());
    for (int e = 1; e < numInputs; e++) {
        auto aV = block.array(e);
        if (aV == nullptr)
            return ND4J_STATUS_BAD_INPUT;

        if (!shape::equalsSoft(a0->getShapeInfo(), aV->getShapeInfo()))
            return ND4J_STATUS_BAD_DIMENSIONS;
    }

    return ND4J_STATUS_OK;
}
/**
 * Checks that every input array of `block` has the same lengthOf() as the first input.
 *
 * @return ND4J_STATUS_OK when there are no inputs or all lengths match;
 *         ND4J_STATUS_BAD_INPUT when an input array is missing;
 *         ND4J_STATUS_BAD_LENGTH on the first input whose length differs.
 */
Nd4jStatus nd4j::ops::DeclarableOp::validateInputLengthMatch(Context& block) {
    if (block.width() == 0)
        return ND4J_STATUS_OK;

    NDArray *a0 = block.array(0);
    if (a0 == nullptr)
        return ND4J_STATUS_BAD_INPUT;

    const Nd4jLong l0 = a0->lengthOf();

    // input 0 is the reference length, so comparison starts at index 1
    const int numInputs = static_cast<int>(block.width());
    for (int e = 1; e < numInputs; e++) {
        auto aV = block.array(e);
        if (aV == nullptr)
            return ND4J_STATUS_BAD_INPUT;

        if (aV->lengthOf() != l0)
            return ND4J_STATUS_BAD_LENGTH;
    }

    return ND4J_STATUS_OK;
}
void DeclarableOp::registerTypes() {
this->getOpDescriptor()->setSameMode(true);
}
/*
template <typename T>
int* nd4j::ops::DeclarableOp::calculateOutputShape(int* inputShape, nd4j::graph::Block& block) {
// default implementation suits transform, so just returns the same shape
int* newshape;
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape), int);
memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape));
return newshape;
}
*/
}
}