cavis/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp

1145 lines
51 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
2019-06-06 14:21:15 +02:00
//
#include <ops/declarable/DeclarableOp.h>
#include <graph/Status.h>
2019-06-06 14:21:15 +02:00
#include <helpers/ShapeUtils.h>
#include <array/NDArrayFactory.h>
2019-06-06 14:21:15 +02:00
#include <exceptions/graph_exception.h>
#include <graph/exceptions/unresolved_input_exception.h>
Platform helpers (#8216) * platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * 
F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * 
skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>
2019-09-11 20:50:28 +02:00
#include <ops/declarable/OpRegistrator.h>
#include <exceptions/datatype_exception.h>
#include <helpers/StringUtils.h>
#include <cstdarg>
2019-06-06 14:21:15 +02:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
/**
 * Validation helper: reports a failed condition to stdout.
 *
 * @param file       source file name printed in the error header
 * @param line       source line printed in the error header
 * @param condition  value under test; non-zero means the check passed
 * @param argNumber  argument index printed in the error header
 * @param format     printf-style format string describing the failure
 * @param ...        values consumed by @p format
 * @return ND4J_STATUS_OK when @p condition is non-zero, ND4J_STATUS_BAD_PARAMS otherwise
 */
Nd4jStatus conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...) {
    // happy path first: nothing to print when the condition holds
    if (condition)
        return ND4J_STATUS_OK;

    printf("Error at [%s:%i:%i]:\n", file, line, argNumber);

    // forward the caller-supplied message and its variadic arguments
    va_list messageArgs;
    va_start(messageArgs, format);
    vprintf(format, messageArgs);
    va_end(messageArgs);

    printf("\n");
    // flush so the message is emitted before the caller acts on the error status
    fflush(stdout);

    return ND4J_STATUS_BAD_PARAMS;
}
/**
 * Default constructor: creates an op without a descriptor.
 *
 * Every other constructor in this file assigns _descriptor, and the destructor
 * reads and deletes it, so it is explicitly nulled here to keep destruction of a
 * default-constructed op well-defined.
 * NOTE(review): redundant if the header already gives _descriptor an in-class
 * initializer - confirm against DeclarableOp.h.
 */
DeclarableOp::DeclarableOp() {
    _descriptor = nullptr;
}
/**
 * Builds an op whose descriptor is created from a name and a "logical" flag.
 *
 * @param name       op name; stored in _name and forwarded to OpDescriptor
 * @param isLogical  forwarded to OpDescriptor unchanged
 */
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
    this->_name = name;
    this->_descriptor = new OpDescriptor(name, isLogical);
}
/**
 * Builds an op whose descriptor is created from an input count, a name and a scalar flag.
 *
 * @param name       op name; stored in _name and forwarded to OpDescriptor
 * @param numInputs  forwarded to OpDescriptor unchanged
 * @param scalar     forwarded to OpDescriptor unchanged
 */
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
    this->_name = name;
    // note the argument order expected by this OpDescriptor overload: count first, then name
    this->_descriptor = new OpDescriptor(numInputs, name, scalar);
}
/**
 * Builds an op with the given input/output counts and in-place capability.
 *
 * @param numInputs      forwarded to OpDescriptor unchanged
 * @param numOutputs     forwarded to OpDescriptor unchanged
 * @param opName         op name; stored in _name and forwarded to OpDescriptor
 * @param allowsInplace  forwarded to OpDescriptor unchanged
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
    this->_name = opName;
    this->_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
}
/**
 * Builds an op with the given input/output counts, in-place capability and divergence flag.
 *
 * @param numInputs      forwarded to OpDescriptor unchanged
 * @param numOutputs     forwarded to OpDescriptor unchanged
 * @param opName         op name; stored in _name and forwarded to OpDescriptor
 * @param allowsInplace  forwarded to OpDescriptor unchanged
 * @param divergent      forwarded to OpDescriptor unchanged
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
    this->_name = opName;
    this->_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
}
/**
 * Builds an op with the given input/output counts, in-place capability and
 * the tArgs/iArgs values expected by the matching OpDescriptor overload.
 *
 * @param numInputs      forwarded to OpDescriptor unchanged
 * @param numOutputs     forwarded to OpDescriptor unchanged
 * @param opName         op name; stored in _name and forwarded to OpDescriptor
 * @param allowsInplace  forwarded to OpDescriptor unchanged
 * @param tArgs          forwarded to OpDescriptor unchanged
 * @param iArgs          forwarded to OpDescriptor unchanged
 */
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
    this->_name = opName;
    this->_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
}
/**
 * Releases the descriptor and the scalar array held by this op.
 */
DeclarableOp::~DeclarableOp() {
    // delete on a null pointer is a no-op, so no explicit guards are required
    delete _descriptor;
    delete _scalar;
}
/**
 * @return the descriptor pointer stored in this op (not copied)
 */
OpDescriptor* DeclarableOp::getOpDescriptor() {
    return this->_descriptor;
}
std::string *DeclarableOp::getOpName() {
return _descriptor->getOpName();
}
/**
 * @return the hash value reported by the descriptor
 *
 * The descriptor is dereferenced without a null check.
 */
Nd4jLong DeclarableOp::getOpHash() {
    auto descriptor = this->_descriptor;
    return descriptor->getHash();
}
Nullify (#304) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * hamming distance nullification Signed-off-by: raver119 <raver119@gmail.com> * Add output array value assignment for testing/debugging Signed-off-by: Alex Black <blacka101@gmail.com> * don't assign empty arrays Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * few more fixes Signed-off-by: raver119 <raver119@gmail.com> * im2col Signed-off-by: raver119 <raver119@gmail.com> * pooling? Signed-off-by: raver119 <raver119@gmail.com> * more nullified Signed-off-by: raver119 <raver119@gmail.com> * ismax nullified Signed-off-by: raver119 <raver119@gmail.com> * rollback ismax nullification Signed-off-by: raver119 <raver119@gmail.com> * synchronized cublas handle use on per-device basis Signed-off-by: raver119 <raver119@gmail.com> * hiding method from jcpp Signed-off-by: raver119 <raver119@gmail.com> * get rid of test assigns in DeclarableOp Signed-off-by: raver119 <raver119@gmail.com> * get rid of assigns Signed-off-by: raver119 <raver119@gmail.com> * proper deviceId is back Signed-off-by: raver119 <raver119@gmail.com> * include fixed Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-03-20 06:49:28 +01:00
/**
 * Same lookup as getZ(), but additionally calls nullify() on the resolved
 * array when the op is not executed in-place.
 *
 * @param block    execution context
 * @param inputId  index of the output array to resolve
 * @return the array returned by getZ(); may be nullptr
 */
sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) {
    auto z = getZ(block, inputId);

    // nothing to nullify when there is no array, or when the output aliases an input
    if (z == nullptr || block.isInplace())
        return z;

    z->nullify();
    return z;
}
2019-06-06 14:21:15 +02:00
/**
 * Resolves the output array with the given index.
 *
 * Fast path: the array is taken from ctx.fastpath_out(); when that list has no
 * such entry and the op runs in-place, the input array with the same index is
 * used instead.
 *
 * Graph path: for in-place execution the input array is registered in the
 * VariableSpace as the output of this node (the variable is created if missing);
 * otherwise the array is read from the node's existing output variable.
 *
 * @param ctx      execution context
 * @param inputId  index of the output (and, for in-place ops, of the aliased input)
 * @return the resolved array; nullptr if the graph-path output variable holds no usable array
 * @throws std::runtime_error on the fast path when the index cannot be resolved
 */
sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
    NDArray* z = nullptr;

    if (ctx.isFastPath()) {
        const auto &outputs = ctx.fastpath_out();
        // explicit sign check: comparing a negative int against size() would wrap to a huge index
        const bool hasOutput = inputId >= 0 && static_cast<size_t>(inputId) < outputs.size();

        if (hasOutput) {
            z = outputs[inputId];
        } else if (!ctx.isInplace()) {
            throw std::runtime_error("fastpath_out: unresolved output array");
        } else {
            // in-place: output aliases the input with the same index, which must exist
            const auto &inputs = ctx.fastpath_in();
            if (inputId < 0 || static_cast<size_t>(inputId) >= inputs.size())
                throw std::runtime_error("fastpath_in: unresolved input array for inplace output");

            z = inputs[inputId];
        }
    } else {
        std::pair<int, int> pair(ctx.nodeId(), inputId);

        if (ctx.isInplace()) {
            z = ctx.variable(inputId)->getNDArray();

            // the output variable may not exist yet; create it so the array can be attached
            if (!ctx.getVariableSpace()->hasVariable(pair)) {
                auto var = new Variable();
                ctx.getVariableSpace()->putVariable(pair, var);
            }

            // store the input array as this node's output array
            auto var = ctx.getVariableSpace()->getVariable(pair);
            var->markRemovable(false);
            var->setNDArray(z);
        } else {
            auto var = ctx.variable(pair);
            if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) {
                z = var->getNDArray();
            } else {
                // best-effort: report and fall through with z == nullptr
                nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId());
            }
        }
    }

    return z;
}
int sd::ops::DeclarableOp::prepareOutputs(Context &ctx) {
2019-06-06 14:21:15 +02:00
auto workspace = ctx.getWorkspace();
GraphProfile *prof = nullptr;
NodeProfile *node = nullptr;
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
bool canUseFastPath = true;
2019-06-06 14:21:15 +02:00
auto fp = ctx.isFastPath();
if (Environment::getInstance().isProfiling()) {
2019-06-06 14:21:15 +02:00
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
prof = ctx.getVariableSpace()->flowPath()->profile();
node = prof->nodeById(ctx.nodeId());
}
}
if (ctx.isInplace()) {
if (Environment::getInstance().isProfiling() && node != nullptr) {
if (fp) {
//
} else {
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
NDArray *array = var->getNDArray();
node->addInputShape(array->shapeInfo());
node->addOutputShape(array->shapeInfo());
}
}
}
}
// if that's not fp, we can still propagate inputs and outputs
if (!fp) {
int cnt = 0;
auto id = ctx.nodeId();
auto vs = ctx.getVariableSpace();
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
NDArray *array = var->getNDArray();
ctx.setInputArray(cnt, array);
ctx.setOutputArray(cnt, array);
// in case of this override we might need to update outputs in the Graph VariableSpace as well
if (vs != nullptr) {
if (vs->hasVariable(id, cnt)) {
auto v2 = vs->getVariable(id, cnt);
if (!v2->hasNDArray()) {
v2->setNDArray(array);
v2->markRemovable(false);
}
} else {
auto v2 = vs->putVariable(id, cnt, array);
v2->markRemovable(false);
}
}
cnt++;
} else {
canUseFastPath = false;
}
}
}
if (!canUseFastPath)
ctx.forbidFastPath(true);
2019-06-06 14:21:15 +02:00
// do nothing, getZ result will do the trick
return static_cast<int>(ctx.width());
} else {
// if op is not inplace - we should pre-allocate arrays
ShapeList inSha;
int results = 0;
if (Environment::getInstance().isProfiling() && node != nullptr)
2019-06-06 14:21:15 +02:00
inputStart = std::chrono::system_clock::now();
int cntIn = 0;
// we build list of input shapes
if (fp) {
2019-06-06 14:21:15 +02:00
for (const auto p:ctx.fastpath_in()) {
Legacy API changes (#441) * initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored buffer() and shapeInfo() methods usage with NDArray class. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt Graph class methods to use const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt choose op to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt where op shape method to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt lstsq op to use constant empty shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt matrix_diag_part op shape routine to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt determinant ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt mean_pairwssqerr_loss ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for loss ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt log_loss op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt dilation2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted deconv2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted dynamicRNN op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for ops. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for lstm layer ops. Signed-off-by: shugeo <sgazeos@gmail.com> * few updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * first cuda tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Adopt constant shapes for sconv2d ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes for gru ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes with shape methods for segment ops and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with unsorted_segment_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with gamma op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods of reduce_stddev ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for reduce_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape method for squeeze op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt strided_slice shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored concat op shape method to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape method for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted split op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted tile ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Added const cast for mkldnn routines handles. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetic changes to proper usage of constant pointers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored depthToSpace helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored histogram helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored im2col helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored gather and gatherND helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage on percentile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed gather shape with helpers and range buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with space to depth helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage and constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with LUP decomposition> Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored onehot_ helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pad and prefix to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactoed softmax helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed space to batch helpers to use buffers properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed stack and split helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with sparse to dense helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with mindistance_ helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with tile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with legacy pairwise bool ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple of methods to adopt constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed broadcasting with constant shape." Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const usage with inplace reverse and constant shapes with legacy reduction. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops with const shapes. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored sort to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected sort for constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with special methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored Context to conform with constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * CUDA broadcasting headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * pairwise/indexreduce/random headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored native ops to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * legacy reduce3/scalar headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected pullRow signature and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected routines to proper use of constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with NDArray tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed native ops tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed special concat routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with test. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with a test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored TAD.h and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored calcStrides* routines to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed miscelaneous errors with constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected definitions for declared functions. 
Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed const shapes with shape routines. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed shape method for broadcastable case. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * xw_plus_b BP shape fn restored Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed signatures with broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Repaired backprops shape methods for a set of operations. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored broadcast bool for cuda. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods for 3 args with const qualifier. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed a couple of kernel signatures for broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels signatures for const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise methods to persistent buffers and shapes usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with scalar kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored indexreduce kernels signatures to use const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise bool kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored random special ops to conform with const shapes and buffers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored native ops to conform with const shapes and buffers under cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetical changes only. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes and buffers error. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected start pos routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods to conform with const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored helpers to use proper methods instead. Signed-off-by: shugeo <sgazeos@gmail.com> * bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected const shape cases with sort and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes for sort. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored kernel declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernel declarations to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed segment helpers kernels declarations and so on to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with segment and solve helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernel declaration with adjustWeight helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed cuda implementations for constant shape helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted const shape usage with kernels. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted top_k kernels to use const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernels declarations to adopt const shapes with helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored NDArray definitions to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes with image suppression helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Slight improvement with buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with definitions. Signed-off-by: shugeo <sgazeos@gmail.com> * minor updates on cpu side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored const shape usage with ConstantDescritor and native ops with cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tear and tile kernels to adopt with const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * softmax_loop fix Signed-off-by: raver119 <raver119@gmail.com> * update missing signature Signed-off-by: raver119@gmail.com <raver119@gmail.com> * softmax again Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more missing consts Signed-off-by: raver119 <raver119@gmail.com> * new methods updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-09 07:06:14 +02:00
inSha.push_back(p == nullptr ? nullptr : p->shapeInfo());
2019-06-06 14:21:15 +02:00
}
} else {
int arrCnt = 0;
2019-06-06 14:21:15 +02:00
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
NDArray *array = var->getNDArray();
if (array == nullptr)
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
Legacy API changes (#441) * initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored buffer() and shapeInfo() methods usage with NDArray class. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt Graph class methods to use const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt choose op to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt where op shape method to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt lstsq op to use constant empty shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt matrix_diag_part op shape routine to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt determinant ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt mean_pairwssqerr_loss ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for loss ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt log_loss op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt dilation2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted deconv2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted dynamicRNN op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for ops. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for lstm layer ops. Signed-off-by: shugeo <sgazeos@gmail.com> * few updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * first cuda tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Adopt constant shapes for sconv2d ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes for gru ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes with shape methods for segment ops and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with unsorted_segment_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with gamma op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods of reduce_stddev ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for reduce_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape method for squeeze op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt strided_slice shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored concat op shape method to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape method for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted split op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted tile ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Added const cast for mkldnn routines handles. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetic changes to proper usage of constant pointers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored depthToSpace helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored histogram helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored im2col helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored gather and gatherND helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage on percentile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed gather shape with helpers and range buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with space to depth helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage and constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with LUP decomposition> Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored onehot_ helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pad and prefix to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactoed softmax helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed space to batch helpers to use buffers properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed stack and split helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with sparse to dense helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with mindistance_ helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with tile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with legacy pairwise bool ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple of methods to adopt constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed broadcasting with constant shape." Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const usage with inplace reverse and constant shapes with legacy reduction. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops with const shapes. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored sort to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected sort for constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with special methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored Context to conform with constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * CUDA broadcasting headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * pairwise/indexreduce/random headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored native ops to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * legacy reduce3/scalar headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected pullRow signature and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected routines to proper use of constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with NDArray tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed native ops tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed special concat routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with test. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with a test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored TAD.h and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored calcStrides* routines to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed miscelaneous errors with constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected definitions for declared functions. 
Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed const shapes with shape routines. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed shape method for broadcastable case. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * xw_plus_b BP shape fn restored Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed signatures with broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Repaired backprops shape methods for a set of operations. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored broadcast bool for cuda. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods for 3 args with const qualifier. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed a couple of kernel signatures for broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels signatures for const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise methods to persistent buffers and shapes usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with scalar kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored indexreduce kernels signatures to use const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise bool kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored random special ops to conform with const shapes and buffers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored native ops to conform with const shapes and buffers under cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetical changes only. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes and buffers error. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected start pos routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods to conform with const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored helpers to use proper methods instead. Signed-off-by: shugeo <sgazeos@gmail.com> * bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected const shape cases with sort and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes for sort. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored kernel declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernel declarations to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed segment helpers kernels declarations and so on to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with segment and solve helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernel declaration with adjustWeight helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed cuda implementations for constant shape helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted const shape usage with kernels. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted top_k kernels to use const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernels declarations to adopt const shapes with helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored NDArray definitions to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes with image suppression helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Slight improvement with buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with definitions. Signed-off-by: shugeo <sgazeos@gmail.com> * minor updates on cpu side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored const shape usage with ConstantDescritor and native ops with cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tear and tile kernels to adopt with const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * softmax_loop fix Signed-off-by: raver119 <raver119@gmail.com> * update missing signature Signed-off-by: raver119@gmail.com <raver119@gmail.com> * softmax again Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more missing consts Signed-off-by: raver119 <raver119@gmail.com> * new methods updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-09 07:06:14 +02:00
inSha.push_back(array->shapeInfo());
// we're also filling ctx with arrays
if (canUseFastPath)
ctx.setInputArray(arrCnt++, array);
} else {
canUseFastPath = false;
2019-06-06 14:21:15 +02:00
}
cntIn++;
}
}
// if we override shape function, we'll return size of fastPath
if (fp && ctx.shapeFunctionOverride()) {
return (int) ctx.fastpath_out().size();
}
2019-06-06 14:21:15 +02:00
// optionally saving input time
if (Environment::getInstance().isProfiling() && node != nullptr) {
2019-06-06 14:21:15 +02:00
inputEnd = std::chrono::system_clock::now();
auto inputTime = std::chrono::duration_cast<std::chrono::nanoseconds>(inputEnd - inputStart).count();
node->setInputTime(inputTime);
// saving output shapes in profile
for (int e = 0; e < inSha.size(); e++)
node->addInputShape(inSha.at(e));
2019-06-06 14:21:15 +02:00
shapeStart = std::chrono::system_clock::now();
}
auto outSha = this->calculateOutputShape(&inSha, ctx);
results = outSha->size();
// optionally saving shapeTime
if (Environment::getInstance().isProfiling() && node != nullptr) {
2019-06-06 14:21:15 +02:00
shapeEnd = std::chrono::system_clock::now();
auto prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(shapeEnd - shapeStart).count();
node->setShapeFunctionTime(prepTime);
// saving output shapes in profile
for (int e = 0; e < outSha->size(); e++)
node->addOutputShape(outSha->at(e));
2019-06-06 14:21:15 +02:00
arrayStart = std::chrono::system_clock::now();
}
int cnt = 0;
2019-06-06 14:21:15 +02:00
for (auto out: *outSha->asVector()) {
if (!fp) {
2019-06-06 14:21:15 +02:00
// we need to check, if Z is really needed
std::pair<int, int> pair(ctx.nodeId(), cnt++);
if (!ctx.isValueAvailable(pair.second)) {
if (Environment::getInstance().isDebugAndVerbose())
2019-06-06 14:21:15 +02:00
shape::printShapeInfoLinear("Going to create variable with shape", out);
Nullify (#304) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * hamming distance nullification Signed-off-by: raver119 <raver119@gmail.com> * Add output array value assignment for testing/debugging Signed-off-by: Alex Black <blacka101@gmail.com> * don't assign empty arrays Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * few more fixes Signed-off-by: raver119 <raver119@gmail.com> * im2col Signed-off-by: raver119 <raver119@gmail.com> * pooling? Signed-off-by: raver119 <raver119@gmail.com> * more nullified Signed-off-by: raver119 <raver119@gmail.com> * ismax nullified Signed-off-by: raver119 <raver119@gmail.com> * rollback ismax nullification Signed-off-by: raver119 <raver119@gmail.com> * synchronized cublas handle use on per-device basis Signed-off-by: raver119 <raver119@gmail.com> * hiding method from jcpp Signed-off-by: raver119 <raver119@gmail.com> * get rid of test assigns in DeclarableOp Signed-off-by: raver119 <raver119@gmail.com> * get rid of assigns Signed-off-by: raver119 <raver119@gmail.com> * proper deviceId is back Signed-off-by: raver119 <raver119@gmail.com> * include fixed Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-03-20 06:49:28 +01:00
// we're creating non-initialized array here
auto outArr = new NDArray(out, true, ctx.launchContext(), false);
2019-06-06 14:21:15 +02:00
ctx.pushNDArrayToVariableSpace(pair, outArr);
if (canUseFastPath)
ctx.setOutputArray(pair.second, outArr);
2019-06-06 14:21:15 +02:00
} else {
// validate/compare shapes here. existent vs provided in outSha
auto var = ctx.variable(pair);
auto shape = var->getNDArray()->shapeInfo();
if (canUseFastPath)
ctx.setOutputArray(pair.second, var->getNDArray());
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
2019-06-06 14:21:15 +02:00
auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(shape);
2021-02-01 06:31:20 +01:00
auto eShapeInfoString = ShapeUtils::shapeInfoAsString(out);
auto aShapeInfoString = ShapeUtils::shapeInfoAsString(shape);
2019-06-06 14:21:15 +02:00
//outSha->destroy();
delete outSha;
2021-02-01 06:31:20 +01:00
nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i with expected shape info %s and output shape info %s\n", eShape.c_str(), aShape.c_str(), pair.second,eShapeInfoString.c_str(),aShapeInfoString.c_str());
2019-06-06 14:21:15 +02:00
throw std::runtime_error("Expected vs provided shapes mismatch");
}
//checking out data type equality
if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) {
std::string msg = "Provided array [" + StringUtils::valueToString<int>(pair.second) + "] has unexpected data type";
throw sd::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape));
}
2019-06-06 14:21:15 +02:00
}
} else {
auto fout = ctx.fastpath_out();
auto idx = cnt++;
if (fout.size() <= idx) {
// array doesnt exist
auto outArr = new NDArray(out, true, ctx.launchContext());
ctx.setOutputArray(idx, outArr, true);
} else {
auto array = fout[idx];
2021-02-01 06:31:20 +01:00
int shapeEquals = shape::equalsSoft(out, array->shapeInfo());
int arrayEmpty = array->isEmpty();
// checking out shape equality
2021-02-01 06:31:20 +01:00
if (!shapeEquals || arrayEmpty) {
2019-06-06 14:21:15 +02:00
auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
2021-02-01 06:31:20 +01:00
auto eShapeInfoString = ShapeUtils::shapeInfoAsString(out);
auto aShapeInfoString = ShapeUtils::shapeInfoAsString(array->shapeInfo());
if(eShapeInfoString != aShapeInfoString) {
//outSha->destroy();
delete outSha;
nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i with expected shape info %s and output shape info %s. Conditions, shapeEquals: %d, array empty: %d\n", eShape.c_str(), aShape.c_str(), idx,eShapeInfoString.c_str(),aShapeInfoString.c_str(),shapeEquals,arrayEmpty);
throw std::runtime_error("Output array did not match expected shape.");
}
2019-06-06 14:21:15 +02:00
}
}
}
}
if (!canUseFastPath)
ctx.forbidFastPath(true);
2019-06-06 14:21:15 +02:00
delete outSha;
// saving arrayTime
if (Environment::getInstance().isProfiling() && node != nullptr) {
2019-06-06 14:21:15 +02:00
arrayEnd = std::chrono::system_clock::now();
auto arrayTime = std::chrono::duration_cast<std::chrono::nanoseconds>(arrayEnd - arrayStart).count();
node->setArrayTime(arrayTime);
}
return results;
}
}
// Pointer flavour of storeResult: dereferences `array` and forwards to the
// reference-taking overload. `array` is dereferenced unconditionally, so the
// caller must pass a non-null pointer.
void sd::ops::DeclarableOp::storeResult(Context &block, int outputNumber, NDArray* array) {
    NDArray &target = *array;
    this->storeResult(block, outputNumber, target);
}
// Pushes `array` into the variable space of `ctx` under the key
// (ctx.nodeId(), outputNumber). The last argument handed to
// pushNDArrayToVariableSpace is the negation of ctx.isInplace().
void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, int outputNumber, NDArray& array) {
    const bool notInplace = !ctx.isInplace();
    NDArray *result = &array;
    ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, result, notInplace);
}
// Makes sure variable (nodeId, 0) of `block` holds an NDArray whose length
// matches the one described by the shape info `shape`.
//
// - If the variable has no array yet, a new one is created.
// - If it has an array of a different length, that array is deleted and replaced.
// - If the lengths already match, nothing is changed.
//
// @param block  context whose variable (getNodeId(), 0) is inspected/updated
// @param shape  shape info buffer describing the requested array
// @return always true
bool sd::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) {
    auto var = block.variable(block.getNodeId(), 0);
    auto workspace = block.getWorkspace();

    const Nd4jLong len = shape::length(shape);

    auto current = var->getNDArray();
    if (current != nullptr && current->lengthOf() == len)
        return true;

    // stale array of a different length: drop it before attaching the replacement
    // (deleting a null pointer is a no-op, so the first-run case needs no branch)
    delete current;

    // The buffer length is given in bytes, so it has to be scaled by the element
    // width of the requested data type; previously it was scaled by
    // sizeof(int8_t) == 1, which under-allocates for every wider type.
    const auto dtype = ArrayOptions::dataType(shape);
    const auto lengthInBytes = static_cast<size_t>(len) * DataTypeUtils::sizeOf(dtype);
    auto buffer = std::make_shared<DataBuffer>(lengthInBytes, dtype, workspace);

    // The descriptor is built straight from the caller's shape info. The former
    // ALLOCATE'd scratch copy was never released.
    // NOTE(review): relies on ShapeDescriptor copying the values out of the
    // buffer instead of retaining the pointer - confirm against ShapeDescriptor.
    var->setNDArray(new NDArray(buffer, ShapeDescriptor(shape), block.launchContext()));

    return true;
}
// Makes sure variable (nodeId, 0) of `block` holds an NDArray whose length
// equals shape::length(shape). A missing array is created; an array of a
// different length is deleted and replaced; a matching array is left alone.
// New arrays are built from `order`, `shape`, block.dataType() and
// block.launchContext().
//
// @return always true
bool sd::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list<Nd4jLong>& shape, char order) {
    auto target = block.variable(block.getNodeId(), 0);
    // fetched as in the original implementation; not used further down
    auto workspace = block.getWorkspace();

    const Nd4jLong requestedLength = shape::length(shape);
    auto existing = target->getNDArray();

    // an array of the requested length is already attached - keep it
    if (existing != nullptr && existing->lengthOf() == requestedLength)
        return true;

    // length mismatch: the old array goes away before the new one is attached
    if (existing != nullptr)
        delete existing;

    target->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext()));
    return true;
}
// Validates the data types of the inputs and of the already available outputs
// of `block` against the op descriptor.
//
// Steps:
//  1. Lazily runs registerTypes() once, guarded by `_registrator`.
//  2. Checks every non-null input with _descriptor->checkInputMatch().
//  3. Checks every available output:
//       - same mode:  output dtype must equal the dtype of the input with the
//                     same index (or of input 0 when the output index is
//                     beyond the input count; skipped when there are no inputs)
//       - inherit:    output dtype must be present in the collected input dtypes
//       - otherwise:  _descriptor->checkOutputMatch() must accept it
//
// Inputs/outputs are read from the fastpath arrays when block.isFastPath(),
// and from the variable space otherwise.
//
// @return ND4J_STATUS_OK when everything matches, ND4J_STATUS_BAD_ARGUMENTS
//         (after logging the offending index and dtype) on the first mismatch
Nd4jStatus sd::ops::DeclarableOp::validateDataTypes(Context& block) {
    {
        // scoped guard instead of manual lock()/unlock(): the mutex is released
        // even if registerTypes() throws
        std::lock_guard<decltype(_registrator)> guard(_registrator);
        if (!_registered) {
            _registered = true;
            this->registerTypes();
        }
    }

    // logs a rejected input and yields the failure status
    auto rejectInput = [this](int index, sd::DataType type) -> Nd4jStatus {
        auto ctype = DataTypeUtils::asString(type);
        nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
                    _descriptor->getOpName()->data(), index, ctype.c_str());
        return ND4J_STATUS_BAD_ARGUMENTS;
    };

    // logs a same-mode output mismatch and yields the failure status
    auto rejectSameModeOutput = [this](int index, sd::DataType type) -> Nd4jStatus {
        auto t = DataTypeUtils::asString(type);
        nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                    _descriptor->getOpName()->data(), index, t.c_str());
        return ND4J_STATUS_BAD_ARGUMENTS;
    };

    // dtypes of the non-null inputs, in encounter order
    std::vector<sd::DataType> inputTypes(block.width());

    // output check for the non-same modes (inherit / explicitly declared types);
    // logs and returns false on mismatch
    auto declaredOutputAccepted = [this, &inputTypes](int index, sd::DataType type) -> bool {
        if (_descriptor->isInherit(index)) {
            // in inherit mode, output type must be the same as one of input types
            if (std::find(inputTypes.begin(), inputTypes.end(), type) == inputTypes.end()) {
                auto t = DataTypeUtils::asString(type);
                nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
                            _descriptor->getOpName()->data(), index, t.c_str());
                return false;
            }
        } else if (!_descriptor->checkOutputMatch(index, type)) {
            auto t = DataTypeUtils::asString(type);
            nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
                        _descriptor->getOpName()->data(), index, t.c_str());
            return false;
        }
        return true;
    };

    // ---- inputs ----
    int cnt = 0, inT = 0;
    if (block.isFastPath()) {
        for (auto array : block.fastpath_in()) {
            // null entries are skipped and do not advance the input index
            if (array == nullptr)
                continue;

            inputTypes[inT++] = array->dataType();
            if (!_descriptor->checkInputMatch(cnt, array->dataType()))
                return rejectInput(cnt, array->dataType());

            cnt++;
        }
    } else {
        for (auto &p : *(block.inputs())) {
            auto var = block.variable(p);

            // only validating non-null variables
            if (var != nullptr && var->hasNDArray()) {
                auto array = var->getNDArray();
                inputTypes[inT++] = array->dataType();
                if (!_descriptor->checkInputMatch(cnt, array->dataType()))
                    return rejectInput(cnt, array->dataType());
            }

            cnt++;
        }
    }

    // ---- outputs ----
    if (block.isFastPath()) {
        int index = 0;
        for (auto array : block.fastpath_out()) {
            // null entries are skipped and do not advance the output index
            if (array == nullptr)
                continue;

            auto cType = array->dataType();

            if (_descriptor->isSameMode()) {
                // outputs beyond the input count are compared against input 0
                const bool beyondInputs = index >= block.width();
                if (beyondInputs && block.fastpath_in().size() == 0)
                    continue;

                auto ia = block.fastpath_in()[beyondInputs ? 0 : index];
                // null inputs are tolerated by the input pass above, so a null
                // reference input means there is nothing to compare against
                if (ia != nullptr && ia->dataType() != cType)
                    return rejectSameModeOutput(index, cType);
            } else if (!declaredOutputAccepted(index, cType)) {
                return ND4J_STATUS_BAD_ARGUMENTS;
            }

            index++;
        }
    } else {
        // checking optionally available outputs
        auto varSpace = block.getVariableSpace();
        for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
            // the first missing output index ends the scan
            if (varSpace == nullptr || !varSpace->hasVariable(block.nodeId(), index))
                break;

            auto var = block.variable(block.nodeId(), index);

            // only validating non-null variables
            if (var == nullptr || !var->hasNDArray())
                continue;

            auto cType = var->getNDArray()->dataType();

            if (_descriptor->isSameMode()) {
                // outputs beyond the input count are compared against input 0
                const bool beyondInputs = index >= block.width();
                if (beyondInputs && block.width() == 0)
                    continue;

                auto iv = block.variable(beyondInputs ? 0 : index);
                // guard against an input variable that is missing or carries no array
                auto inputArray = iv == nullptr ? nullptr : iv->getNDArray();
                if (inputArray != nullptr && inputArray->dataType() != cType)
                    return rejectSameModeOutput(index, cType);
            } else if (!declaredOutputAccepted(index, cType)) {
                return ND4J_STATUS_BAD_ARGUMENTS;
            }
        }
    }

    return ND4J_STATUS_OK;
}
Nd4jStatus sd::ops::DeclarableOp::execute(Context* block) {
2019-06-06 14:21:15 +02:00
nd4j_debug("Executing op: [%s]\n", this->getOpName()->c_str());
std::chrono::time_point<std::chrono::system_clock> timeEnter, timeStart, timeEnd;
Nd4jLong prepTime, outerTime;
Nd4jLong memoryBefore = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
if (Environment::getInstance().isProfiling())
2019-06-06 14:21:15 +02:00
timeEnter = std::chrono::system_clock::now();
// basic validation: ensure inputs are set
REQUIRE_OK(this->validateNonEmptyInput(*block));
// ensure number of IArgs, TArgs match our expectations
REQUIRE_OK(this->validateArguments(*block));
// validating data types for inputs and (optionally) outputs
REQUIRE_OK(this->validateDataTypes(*block));
// this method will allocate output NDArrays for this op
auto numOutputs = this->prepareOutputs(*block);
if (Environment::getInstance().isProfiling()) {
2019-06-06 14:21:15 +02:00
timeStart = std::chrono::system_clock::now();
prepTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeStart - timeEnter).count();
}
Platform helpers (#8216) * platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * 
F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * 
skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>
2019-09-11 20:50:28 +02:00
Nd4jStatus status;
bool hasHelper = false;
// platform helpers use might be forbidden for various reasons, so we'll check it out first
if (block->helpersAllowed() && sd::Environment::getInstance().helpersAllowed()) {
// if we have platform-specific helper for this op - invoke it
if (OpRegistrator::getInstance().hasHelper(this->getOpHash(), block->engine())) {
auto helper = OpRegistrator::getInstance().getPlatformHelper(this->getOpHash(), block->engine());
if (helper->isUsable(*block)) {
status = helper->invokeHelper(*block);
hasHelper = true;
}
Platform helpers (#8216) * platform helpers draft Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * disable platform cmake Signed-off-by: raver119 <raver119@gmail.com> * another draft Signed-off-by: raver119 <raver119@gmail.com> * mkldnn convolution refactored Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * one more safety check Signed-off-by: raver119 <raver119@gmail.com> * prototype works Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * force static library mode for mkldnn Signed-off-by: raver119 <raver119@gmail.com> * - ismax fix - experimental arg fix - don't enforce openblas on Apple hardware Signed-off-by: raver119 <raver119@gmail.com> * bunch of small fixes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * declare concurrent Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - MKLDNN version upgrade to 1.0.2 - avgpool2d/maxpool2d APIs update Signed-off-by: raver119 <raver119@gmail.com> * - avgpool2d_bp/maxpool2d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * - conv2d/batchnorm APIs update Signed-off-by: raver119 <raver119@gmail.com> * - lrn/conv2d_bp/conv3d/conv3d_bp APIs update Signed-off-by: raver119 <raver119@gmail.com> * all ops converted to MKLDNN 1.x Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * namespace for platform helpers Signed-off-by: raver119 <raver119@gmail.com> * make sure platform helpers aren't opimized out Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * build cpu_features on x86 systems Signed-off-by: raver119 <raver119@gmail.com> * more of cpu_features Signed-off-by: raver119 <raver119@gmail.com> * - mkldnn removed from java - cpu_features checks in CpuNDArrayFactory Signed-off-by: raver119 <raver119@gmail.com> * 
F16C definition renamed Signed-off-by: raver119 <raver119@gmail.com> * some mkldnn rearrangements Signed-off-by: raver119 <raver119@gmail.com> * check supported instructions before doing anything Signed-off-by: raver119 <raver119@gmail.com> * typo Signed-off-by: raver119 <raver119@gmail.com> * missied impl Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC option Signed-off-by: raver119 <raver119@gmail.com> * conv2d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool2d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * avgpool3d_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * maxpool bp leaks fixed Signed-off-by: raver119 <raver119@gmail.com> * printf removed Signed-off-by: raver119 <raver119@gmail.com> * batchnorm fix Signed-off-by: raver119 <raver119@gmail.com> * AVX warning/error polishing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * More polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * Polish Signed-off-by: AlexDBlack <blacka101@gmail.com> * remove previous MKL-DNN support layer Signed-off-by: raver119 <raver119@gmail.com> * avx2 tweak Signed-off-by: raver119 <raver119@gmail.com> * allow static for apple Signed-off-by: raver119@gmail.com <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * exclude mkldnn in one more place Signed-off-by: raver119 <raver119@gmail.com> * restore OPENBLAS_PATH use Signed-off-by: raver119 <raver119@gmail.com> * add runtime check for avx/avx2 support Signed-off-by: raver119 <raver119@gmail.com> * convolution_auto Signed-off-by: raver119 <raver119@gmail.com> * Add logic for helper argument * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * few tweaks Signed-off-by: raver119 <raver119@gmail.com> * 
skip OpTracker props for non-x86 builds Signed-off-by: raver119 <raver119@gmail.com> * linux arm isn't x86 :) Signed-off-by: raver119 <raver119@gmail.com> * avx-512 Signed-off-by: raver119 <raver119@gmail.com> * CUDA presets fix Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC Signed-off-by: raver119 <raver119@gmail.com> * prefetchw for avx2 Signed-off-by: raver119 <raver119@gmail.com> * BUILD_PIC again Signed-off-by: raver119 <raver119@gmail.com>
2019-09-11 20:50:28 +02:00
}
}
// if we don't have platform-specific helper - invoke generic implementation
if (!hasHelper)
status = this->validateAndExecute(*block);
2019-06-06 14:21:15 +02:00
// optionally saving execution time
if (Environment::getInstance().isProfiling()) {
2019-06-06 14:21:15 +02:00
timeEnd = std::chrono::system_clock::now();
outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
block->setInnerTime(outerTime);
}
if (Environment::getInstance().isProfiling() && block->getVariableSpace() != nullptr) {
2019-06-06 14:21:15 +02:00
auto fp = block->getVariableSpace()->flowPath();
if (fp != nullptr) {
auto p = fp->profile();
if (p != nullptr) {
Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize();
Nd4jLong memoryUsed = memoryAfter - memoryBefore;
p->nodeById(block->nodeId())->setPreparationTime(prepTime);
p->nodeById(block->nodeId())->setExecutionTime(outerTime);
p->nodeById(block->nodeId())->setTotalSize(memoryUsed);
}
}
}
// now we print out all outputs for this node
if (sd::Environment::getInstance().isDebugAndVerbose()) {
2019-06-06 14:21:15 +02:00
auto vs = block->getVariableSpace();
for (int e = 0; e < numOutputs; e++) {
// if given output index doesn't exist - we're done
if (!block->isFastPath()) {
if (!vs->hasVariable(block->nodeId(), e))
break;
} else {
// we have to check either in or out stack, depending on isInplace()
if (block->isInplace()) {
if (block->fastpath_in().size() <= e)
break;
} else {
if (block->fastpath_out().size() <= e)
break;
}
}
auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray();
auto shape = ShapeUtils::shapeAsString(array);
auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32);
auto type = DataTypeUtils::asString(array->dataType());
nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", block->nodeId(), e, shape.c_str(), type.c_str(), first.c_str());
}
}
return status;
}
/**
 * Disabled hook for replacing the NDArray stored as output `outputIdx` of a node.
 *
 * The previous implementation stored `array` in the node's Variable inside the
 * VariableSpace (creating the Variable when it did not exist yet). That code path
 * is switched off: every call throws.
 *
 * @param block     execution context of the node (unused)
 * @param outputIdx index of the output that was meant to be replaced (unused)
 * @param array     replacement array (unused)
 * @throws std::runtime_error always
 */
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArray *array) {
    // Former behaviour, kept only as a pointer for whoever revives this:
    //   block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array);
    throw std::runtime_error("Overwrite result used!");
}
/**
 * Disabled hook for replacing the NDArrayList stored as output `outputIdx` of a node.
 *
 * The previous implementation attached `list` to the node's Variable inside the
 * VariableSpace (creating the Variable when it did not exist yet). That code path
 * is switched off: every call throws.
 *
 * @param block     execution context of the node (unused)
 * @param outputIdx index of the output that was meant to be replaced (unused)
 * @param list      replacement list (unused)
 * @throws std::runtime_error always
 */
void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArrayList *list) {
    // Former behaviour, kept only as a pointer for whoever revives this:
    //   block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list);
    throw std::runtime_error("Overwrite result used!");
}
/**
 * Checks that the context carries enough T (floating point) and I (integer) arguments.
 *
 * For each argument kind the count declared by the op descriptor is interpreted as:
 *   - positive N : at least N arguments must be supplied (extra arguments are accepted);
 *   - -1         : the number is variable, but at least one argument must be supplied;
 *   - otherwise  : nothing is checked.
 *
 * @param block execution context holding the arguments
 * @return ND4J_STATUS_OK when both checks pass, ND4J_STATUS_BAD_PARAMS otherwise
 */
Nd4jStatus sd::ops::DeclarableOp::validateArguments(Context& block) {
    const int expectedT = _descriptor->getNumberOfTArgs();
    if (expectedT > 0) {
        // size() yields size_t; convert once so the value matches the %i specifier below
        const int providedT = static_cast<int>(block.getTArguments()->size());
        if (providedT < expectedT) {
            nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName()->c_str(), expectedT, providedT);
            return ND4J_STATUS_BAD_PARAMS;
        }
    } else if (expectedT == -1) {
        if (block.getTArguments()->empty()) {
            nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
            return ND4J_STATUS_BAD_PARAMS;
        }
    }

    const int expectedI = _descriptor->getNumberOfIArgs();
    if (expectedI > 0) {
        // same size_t -> int conversion as for the T arguments
        const int providedI = static_cast<int>(block.getIArguments()->size());
        if (providedI < expectedI) {
            nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName()->c_str(), expectedI, providedI);
            return ND4J_STATUS_BAD_PARAMS;
        }
    } else if (expectedI == -1) {
        if (block.getIArguments()->empty()) {
            nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str());
            return ND4J_STATUS_BAD_PARAMS;
        }
    }

    return ND4J_STATUS_OK;
}
/**
 * Verifies that every input of the context is an NDArray of the requested rank.
 *
 * @param block execution context whose inputs are inspected
 * @param rank  rank every input array must have
 * @return ND4J_STATUS_OK when there are no inputs or all of them match,
 *         ND4J_STATUS_BAD_INPUT when an input has no NDArray,
 *         ND4J_STATUS_BAD_DIMENSIONS when an input has a different rank
 */
Nd4jStatus sd::ops::DeclarableOp::validateInputDimensions(Context& block, int rank) {
    if (block.width() != 0) {
        for (auto inputId : *block.inputs()) {
            NDArray *candidate = block.variable(inputId)->getNDArray();

            if (candidate == nullptr)
                return ND4J_STATUS_BAD_INPUT;

            if (candidate->rankOf() != rank)
                return ND4J_STATUS_BAD_DIMENSIONS;
        }
    }

    return ND4J_STATUS_OK;
}
/**
 * Verifies that every input of the context has rank 2.
 *
 * @param block execution context whose inputs are inspected
 * @return status reported by validateInputDimensions(block, 2)
 */
Nd4jStatus sd::ops::DeclarableOp::validateInput2D(Context& block) {
    constexpr int requiredRank = 2;
    return this->validateInputDimensions(block, requiredRank);
}
/**
 * Verifies that every input of the context has rank 3.
 *
 * @param block execution context whose inputs are inspected
 * @return status reported by validateInputDimensions(block, 3)
 */
Nd4jStatus sd::ops::DeclarableOp::validateInput3D(Context& block) {
    constexpr int requiredRank = 3;
    return this->validateInputDimensions(block, requiredRank);
}
/**
 * Verifies that every input of the context has rank 4.
 *
 * @param block execution context whose inputs are inspected
 * @return status reported by validateInputDimensions(block, 4)
 */
Nd4jStatus sd::ops::DeclarableOp::validateInput4D(Context& block) {
    constexpr int requiredRank = 4;
    return this->validateInputDimensions(block, requiredRank);
}
/**
 * Verifies that the op received its operands and that none of them is a null reference.
 *
 * Ops whose descriptor declares 0 or -2 inputs are not checked at all. For every other
 * op at least one input must be present, every input must resolve to a Variable, and
 * every Variable of type NDARRAY must hold a non-null array. Arrays that are empty on
 * purpose (hasNDArray() && isEmpty()) are accepted.
 *
 * @param block execution context whose inputs are inspected
 * @return ND4J_STATUS_OK when the inputs are usable, ND4J_STATUS_BAD_INPUT otherwise
 */
Nd4jStatus sd::ops::DeclarableOp::validateNonEmptyInput(Context& block) {
    const auto declaredInputs = this->getOpDescriptor()->getNumberOfInputs();
    if (declaredInputs == -2 || declaredInputs == 0)
        return Status::OK();

    // resolve the name once; the op name may be absent, so never dereference it blindly
    const char *opName = this->getOpName() != nullptr ? this->getOpName()->c_str() : "noname";

    if (block.width() < 1) {
        nd4j_printf("%s: no operands provided for the op\n", opName);
        return ND4J_STATUS_BAD_INPUT;
    }

    // position of the current input; advanced for EVERY input so diagnostics stay accurate
    int cnt = 0;
    for (auto p: *block.inputs()) {
        auto v = block.variable(p);
        if (v == nullptr) {
            nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), opName, cnt, p.first, p.second);
            return ND4J_STATUS_BAD_INPUT;
        }

        if (v->variableType() == VariableType::NDARRAY) {
            NDArray *aV = v->getNDArray();

            // an intentionally empty array is fine; everything else must be a real array
            const bool emptyOnPurpose = v->hasNDArray() && v->isEmpty();
            if (!emptyOnPurpose && (aV == nullptr || !aV->nonNull())) {
                nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), opName, cnt, p.first, p.second);
                return ND4J_STATUS_BAD_INPUT;
            }
        }

        cnt++;
    }

    return ND4J_STATUS_OK;
}
/**
 * Verifies that all input arrays share the ordering of the first input.
 *
 * @param block execution context whose inputs are inspected
 * @return ND4J_STATUS_OK when there are no inputs or all orderings agree,
 *         ND4J_STATUS_BAD_ORDER on the first mismatch
 */
Nd4jStatus sd::ops::DeclarableOp::validateOrdersMatch(Context& block) {
    if (block.width() != 0) {
        // the first input is the reference every other input is compared against
        NDArray *reference = block.variable(0)->getNDArray();

        for (auto inputId : *block.inputs()) {
            NDArray *current = block.variable(inputId)->getNDArray();
            const bool sameOrder = reference->ordering() == current->ordering();
            if (!sameOrder)
                return ND4J_STATUS_BAD_ORDER;
        }
    }

    return ND4J_STATUS_OK;
}
/**
 * Runs the op on raw arrays using a temporary VariableSpace and a caller-supplied RNG.
 *
 * Non-null inputs are registered under ids -1, -2, ...; outputs are registered as
 * (1, 0), (1, 1), ... . All wrapping Variables are marked non-removable.
 *
 * @param rng       random generator handed to the context
 * @param inputs    input arrays; null entries are skipped
 * @param outputs   output arrays
 * @param tArgs     floating point arguments
 * @param iArgs     integer arguments (narrowed to int when stored in the context)
 * @param bArgs     boolean arguments
 * @param dArgs     data type arguments
 * @param isInplace whether the context is marked as in-place
 * @param type      data type set for output 0 of the context
 * @return status reported by execute(Context*)
 */
Nd4jStatus sd::ops::DeclarableOp::execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType>& dArgs, bool isInplace, sd::DataType type) {
    VariableSpace variableSpace;
    FlowPath fp;
    variableSpace.setFlowPath(&fp);

    // register inputs under consecutive negative ids
    std::vector<int> inputIds;
    int nextInputId = -1;
    for (auto input : inputs) {
        if (input != nullptr) {
            auto wrapped = new Variable(input);
            wrapped->markRemovable(false);
            inputIds.push_back(nextInputId);
            variableSpace.putVariable(nextInputId--, wrapped);
        }
    }

    // register outputs as outputs of node 1
    int outputIndex = 0;
    for (auto output : outputs) {
        auto wrapped = new Variable(output);
        wrapped->markRemovable(false);
        std::pair<int,int> key(1, outputIndex++);
        variableSpace.putVariable(key, wrapped);
    }

    Context block(1, &variableSpace, false);
    block.fillInputs(inputIds);
    block.markInplace(isInplace);
    block.setDataType(0, type);

    // we need this line for tests basically
    block.setRng(rng);

    for (auto tArg : tArgs)
        block.getTArguments()->emplace_back(tArg);

    // FIXME: iargs should be Nd4jLong
    for (auto iArg : iArgs)
        block.getIArguments()->emplace_back(static_cast<int>(iArg));

    for (bool bArg : bArgs)
        block.getBArguments()->push_back(static_cast<int>(bArg));

    for (auto dArg : dArgs)
        block.getDArguments()->push_back(dArg);

    return this->execute(&block);
}
/**
 * Runs the op on the given arrays without any T/I/B/D arguments.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @return status of the execution
 */
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, noTArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Runs the op with double-precision T arguments only.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param tArgs   floating point arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
    const std::vector<double> realArgs(tArgs);
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, realArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Runs the op with data type (D) arguments only.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param dArgs   data type arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<sd::DataType> dArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> realArgs(dArgs);
    return execute(inputs, outputs, noTArgs, noIArgs, noBArgs, realArgs);
}
/**
 * Runs the op with single-precision T arguments only; values are widened to double.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param tArgs   floating point arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<float> tArgs) {
    // element-wise float -> double conversion
    const std::vector<double> realArgs(tArgs.begin(), tArgs.end());
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, realArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Runs the op with Nd4jLong integer arguments only.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param iArgs   integer arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> realArgs(iArgs);
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, noTArgs, realArgs, noBArgs, noDArgs);
}
/**
 * Runs the op with int integer arguments only; values are widened to Nd4jLong.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param iArgs   integer arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<int> iArgs) {
    // element-wise int -> Nd4jLong conversion
    const std::vector<Nd4jLong> realArgs(iArgs.begin(), iArgs.end());
    const std::vector<double> noTArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, noTArgs, realArgs, noBArgs, noDArgs);
}
/**
 * Runs the op with boolean arguments only.
 *
 * @param inputs  input arrays
 * @param outputs output arrays
 * @param bArgs   boolean arguments
 * @return status of the execution
 */
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> realArgs(bArgs);
    const std::vector<sd::DataType> noDArgs;
    return execute(inputs, outputs, noTArgs, noIArgs, realArgs, noDArgs);
}
/**
 * Runs the op on raw arrays through a Context that has the arrays set on it directly.
 *
 * @param inputs    input arrays, set on the context in order
 * @param outputs   output arrays, set on the context in order
 * @param tArgs     floating point arguments
 * @param iArgs     integer arguments
 * @param bArgs     boolean arguments
 * @param dArgs     data type arguments
 * @param isInplace when true the context is marked as in-place
 * @return status reported by execute(Context*)
 */
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
    Context ctx(1);

    int slot = 0;
    for (auto input : inputs)
        ctx.setInputArray(slot++, input);

    slot = 0;
    for (auto output : outputs)
        ctx.setOutputArray(slot++, output);

    // the in-place flag is only touched when it was requested
    if (isInplace)
        ctx.markInplace(true);

    ctx.setIArguments(iArgs);
    ctx.setTArguments(tArgs);
    ctx.setBArguments(bArgs);
    ctx.setDArguments(dArgs);

    return execute(&ctx);
}
/**
 * Evaluates the op on the given inputs without any T/I/B/D arguments.
 *
 * @param inputs input arrays
 * @return result set produced by the full evaluate() overload
 */
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, noTArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Evaluates the op with int integer arguments only; values are widened to Nd4jLong.
 *
 * @param inputs input arrays
 * @param iArgs  integer arguments
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
    // element-wise int -> Nd4jLong conversion
    const std::vector<Nd4jLong> realArgs(iArgs.begin(), iArgs.end());
    const std::vector<double> noTArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, noTArgs, realArgs, noBArgs, noDArgs);
}
/**
 * Evaluates the op with Nd4jLong integer arguments only.
 *
 * @param inputs input arrays
 * @param iArgs  integer arguments
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> realArgs(iArgs);
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, noTArgs, realArgs, noBArgs, noDArgs);
}
/**
 * Evaluates the op with single-precision T arguments only; values are widened to double.
 *
 * @param inputs input arrays
 * @param tArgs  floating point arguments
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
    // element-wise float -> double conversion
    const std::vector<double> realArgs(tArgs.begin(), tArgs.end());
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, realArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Evaluates the op with double-precision T arguments only.
 *
 * @param inputs input arrays
 * @param tArgs  floating point arguments
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
    const std::vector<double> realArgs(tArgs);
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, realArgs, noIArgs, noBArgs, noDArgs);
}
/**
 * Evaluates the op with boolean arguments only.
 *
 * @param inputs input arrays
 * @param bArgs  boolean arguments
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> realArgs(bArgs);
    const std::vector<sd::DataType> noDArgs;
    return evaluate(inputs, noTArgs, noIArgs, realArgs, noDArgs);
}
/**
 * Evaluates the op with data type (D) arguments only.
 *
 * @param inputs input arrays
 * @param bArgs  data type arguments (despite the parameter name, these are D args)
 * @return result set produced by the full evaluate() overload
 */
template <>
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
    const std::vector<double> noTArgs;
    const std::vector<Nd4jLong> noIArgs;
    const std::vector<bool> noBArgs;
    const std::vector<sd::DataType> dataTypes(bArgs);
    return evaluate(inputs, noTArgs, noIArgs, noBArgs, dataTypes);
}
/**
 * Executes the op in a temporary VariableSpace and collects its results.
 *
 * Non-null inputs are registered under ids -1, -2, ... and marked non-removable.
 * After a successful run:
 *   - in-place mode: the result set is marked non-removable and filled with the
 *     input arrays as passed in;
 *   - otherwise: outputs (1, 0), (1, 1), ... are collected until the first missing
 *     index. Attached arrays are added via detach(); the others are added as they
 *     are, after their Variable is marked non-removable.
 *
 * @param inputs    input arrays; null entries are skipped
 * @param tArgs     floating point arguments
 * @param iArgs     integer arguments
 * @param bArgs     boolean arguments
 * @param dArgs     data type arguments
 * @param isInplace whether the op runs in-place
 * @return result set carrying the execution status and, on success, the arrays
 */
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
    VariableSpace variableSpace;
    FlowPath fp;
    variableSpace.setFlowPath(&fp);

    // register inputs under consecutive negative ids
    std::vector<int> inputIds;
    int nextInputId = -1;
    for (auto input : inputs) {
        if (input != nullptr) {
            auto wrapped = new Variable(input);
            wrapped->markRemovable(false);
            inputIds.push_back(nextInputId);
            variableSpace.putVariable(nextInputId--, wrapped);
        }
    }

    Context block(1, &variableSpace, false);
    block.setDataType(0, sd::DataType::FLOAT32);
    block.fillInputs(inputIds);
    block.markInplace(isInplace);

    for (auto tArg : tArgs)
        block.getTArguments()->emplace_back(tArg);

    for (auto iArg : iArgs)
        block.getIArguments()->emplace_back(iArg);

    for (bool bArg : bArgs)
        block.getBArguments()->push_back(bArg);

    for (auto dArg : dArgs)
        block.getDArguments()->push_back(dArg);

    const Nd4jStatus status = this->execute(&block);

    ResultSet results;
    if (isInplace)
        results.setNonRemovable();
    results.setStatus(status);

    // nothing to collect when the op failed
    if (status != ND4J_STATUS_OK)
        return results;

    if (isInplace) {
        // in-place: the inputs themselves are the results
        for (auto input : inputs)
            results.push_back(input);
        return results;
    }

    // walk outputs of node 1 until the first index that was never produced
    const int upperBound = DataTypeUtils::max<int>();
    for (int idx = 0; idx < upperBound; ++idx) {
        std::pair<int, int> key(1, idx);
        if (!variableSpace.hasVariable(key))
            break;

        auto holder = variableSpace.getVariable(key);
        auto produced = holder->getNDArray();

        if (produced->isAttached()) {
            results.push_back(produced->detach());
        } else {
            holder->markRemovable(false);
            produced->setContext(sd::LaunchContext::defaultContext());
            results.push_back(produced);
        }
    }

    return results;
}
/**
 * Evaluates the op using the arrays and T/I/B arguments stored in an OpArgsHolder.
 *
 * @param holder    source of input arrays and T/I/B arguments
 * @param isInplace whether the op runs in-place
 * @return result set produced by evaluate()
 */
sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
    // FIXME: add DArgs to OpArgsHolder
    const std::vector<sd::DataType> noDArgs;
    return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), noDArgs, isInplace);
}
Nd4jStatus sd::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) {
2019-06-06 14:21:15 +02:00
if (block.width() == 0)
return ND4J_STATUS_OK;
NDArray *a0 = block.array(0);
for (int e = 1; e < block.width(); e++) {
2019-06-06 14:21:15 +02:00
auto aV = block.array(e);
Legacy API changes (#441) * initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored buffer() and shapeInfo() methods usage with NDArray class. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt Graph class methods to use const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt choose op to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt where op shape method to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt lstsq op to use constant empty shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt matrix_diag_part op shape routine to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt determinant ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt mean_pairwssqerr_loss ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for loss ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt log_loss op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt dilation2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted deconv2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted dynamicRNN op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for ops. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for lstm layer ops. Signed-off-by: shugeo <sgazeos@gmail.com> * few updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * first cuda tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Adopt constant shapes for sconv2d ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes for gru ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes with shape methods for segment ops and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with unsorted_segment_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with gamma op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods of reduce_stddev ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for reduce_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape method for squeeze op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt strided_slice shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored concat op shape method to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape method for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted split op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted tile ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Added const cast for mkldnn routines handles. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetic changes to proper usage of constant pointers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored depthToSpace helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored histogram helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored im2col helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored gather and gatherND helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage on percentile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed gather shape with helpers and range buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with space to depth helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage and constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with LUP decomposition> Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored onehot_ helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pad and prefix to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactoed softmax helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed space to batch helpers to use buffers properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed stack and split helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with sparse to dense helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with mindistance_ helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with tile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with legacy pairwise bool ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple of methods to adopt constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed broadcasting with constant shape." Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const usage with inplace reverse and constant shapes with legacy reduction. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops with const shapes. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored sort to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected sort for constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with special methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored Context to conform with constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * CUDA broadcasting headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * pairwise/indexreduce/random headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored native ops to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * legacy reduce3/scalar headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected pullRow signature and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected routines to proper use of constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with NDArray tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed native ops tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed special concat routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with test. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with a test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored TAD.h and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored calcStrides* routines to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed miscelaneous errors with constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected definitions for declared functions. 
Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed const shapes with shape routines. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed shape method for broadcastable case. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * xw_plus_b BP shape fn restored Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed signatures with broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Repaired backprops shape methods for a set of operations. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored broadcast bool for cuda. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods for 3 args with const qualifier. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed a couple of kernel signatures for broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels signatures for const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise methods to persistent buffers and shapes usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with scalar kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored indexreduce kernels signatures to use const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise bool kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored random special ops to conform with const shapes and buffers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored native ops to conform with const shapes and buffers under cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetical changes only. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes and buffers error. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected start pos routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods to conform with const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored helpers to use proper methods instead. Signed-off-by: shugeo <sgazeos@gmail.com> * bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected const shape cases with sort and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes for sort. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored kernel declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernel declarations to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed segment helpers kernels declarations and so on to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with segment and solve helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernel declaration with adjustWeight helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed cuda implementations for constant shape helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted const shape usage with kernels. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted top_k kernels to use const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernels declarations to adopt const shapes with helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored NDArray definitions to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes with image suppression helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Slight improvement with buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with definitions. Signed-off-by: shugeo <sgazeos@gmail.com> * minor updates on cpu side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored const shape usage with ConstantDescritor and native ops with cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tear and tile kernels to adopt with const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * softmax_loop fix Signed-off-by: raver119 <raver119@gmail.com> * update missing signature Signed-off-by: raver119@gmail.com <raver119@gmail.com> * softmax again Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more missing consts Signed-off-by: raver119 <raver119@gmail.com> * new methods updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-09 07:06:14 +02:00
if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo()))
2019-06-06 14:21:15 +02:00
return ND4J_STATUS_BAD_DIMENSIONS;
}
return ND4J_STATUS_OK;
}
/**
 * Verifies that every input array has the same length as the first input.
 *
 * @param block execution context whose inputs are inspected
 * @return ND4J_STATUS_OK when there are no inputs or all lengths agree,
 *         ND4J_STATUS_BAD_LENGTH on the first mismatch
 */
Nd4jStatus sd::ops::DeclarableOp::validateInputLengthMatch(Context& block) {
    if (block.width() != 0) {
        const Nd4jLong expectedLength = block.array(0)->lengthOf();

        uint32_t idx = 0;
        while (idx < block.width()) {
            if (block.array(idx)->lengthOf() != expectedLength)
                return ND4J_STATUS_BAD_LENGTH;
            ++idx;
        }
    }

    return ND4J_STATUS_OK;
}
/**
 * Reports how this op treats empty inputs.
 *
 * @return samediff::EmptyHandling::EMPTY_SKIP in this base implementation
 */
samediff::EmptyHandling DeclarableOp::emptyHandling() {
    const auto policy = samediff::EmptyHandling::EMPTY_SKIP;
    return policy;
}
2019-06-06 14:21:15 +02:00
/**
 * Default type registration: switches the op descriptor into "same mode".
 */
void DeclarableOp::registerTypes() {
    auto descriptor = this->getOpDescriptor();
    descriptor->setSameMode(true);
}
/*
template <typename T>
int* sd::ops::DeclarableOp::calculateOutputShape(int* inputShape, sd::graph::Block& block) {
2019-06-06 14:21:15 +02:00
// default implementation suits transform, so just returns the same shape
int* newshape;
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape), int);
memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape));
return newshape;
}
*/
}
}