raver119 320924278d
Legacy API changes (#441)
* initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* another initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* one more initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next step

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored buffer() and shapeInfo() methods usage with NDArray class.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt Graph class methods to use const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt choose op to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt where op shape method to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt lstsq op to use constant empty shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt matrix_diag_part op shape routine to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt determinant ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt mean_pairwssqerr_loss ops to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for loss ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt log_loss op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt dilation2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted deconv2d ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted dynamicRNN op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for lstm layer ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few updates

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* first cuda tweak

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Adopt constant shapes for sconv2d ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes for gru ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt constant shapes with shape methods for segment ops and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with unsorted_segment_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted constant shapes with gamma op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods of reduce_stddev ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape methods for reduce_* ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt shape method for squeeze op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt strided_slice shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored concat op shape method to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted shape method for mirror_pad op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted split op shape method.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted tile ops shape methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added const cast for mkldnn routines handles.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetic changes to proper usage of constant pointers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored depthToSpace helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored histogram helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored im2col helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored gather and gatherND helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage on percentile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed gather shape with helpers and range buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with space to depth helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage and constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with LUP decomposition>

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored onehot_ helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pad and prefix to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactoed softmax helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed space to batch helpers to use buffers properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed stack and split helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with sparse to dense helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with mindistance_ helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with tile helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with legacy pairwise bool ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored a couple of methods to adopt constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed broadcasting with constant shape."

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const usage with inplace reverse and constant shapes with legacy reduction.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored sort to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected sort for constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed constant shape usage with special methods.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored Context to conform with constant shape usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* CUDA broadcasting headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* pairwise/indexreduce/random headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored native ops to adopt constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* legacy reduce3/scalar headers

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected pullRow signature and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected routines to proper use of constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored legacy ops tests to use constant shapes properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with NDArray tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed native ops tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed special concat routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed buffer usage with a test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored TAD.h and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored calcStrides* routines to use constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed miscelaneous errors with constant shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Corrected definitions for declared functions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* NativeOps const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed const shapes with shape routines.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed shape method for broadcastable case.

Signed-off-by: shugeo <sgazeos@gmail.com>

* few more const changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* xw_plus_b BP shape fn restored

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed signatures with broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Repaired backprops shape methods for a set of operations.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored broadcast bool for cuda.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods for 3 args with const qualifier.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed a couple of kernel signatures for broadcasting.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels signatures for const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise methods to persistent buffers and shapes usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopt const to buffers and shapes with scalar kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored indexreduce kernels signatures to use const buffers and shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored pairwise bool kernels to adopt cons shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored random special ops to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored native ops to conform with const shapes and buffers under cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cosmetical changes only.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes and buffers error.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected start pos routine.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored methods to conform with const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored helpers to use proper methods instead.

Signed-off-by: shugeo <sgazeos@gmail.com>

* bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* next bunch of changes

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed execScalar declaration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected const shape cases with sort and so on.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes for sort.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored kernel declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernel declarations to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernels declarations to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed segment helpers kernels declarations and so on to adopt const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with segment and solve helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed kernel declaration with adjustWeight helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed cuda implementations for constant shape helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted const shape usage with kernels.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adopted top_k kernels to use const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected kernels declarations to adopt const shapes with helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored NDArray definitions to adopt const shapes and buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shapes with image suppression helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Slight improvement with buffers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored buffer usage with tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed const shape usage with definitions.

Signed-off-by: shugeo <sgazeos@gmail.com>

* minor updates on cpu side

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Refactored const shape usage with ConstantDescritor and native ops with cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored tear and tile kernels to adopt with const shapes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* softmax_loop fix

Signed-off-by: raver119 <raver119@gmail.com>

* update missing signature

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* softmax again

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* few more missing consts

Signed-off-by: raver119 <raver119@gmail.com>

* new methods updated

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-09 08:06:14 +03:00

677 lines
31 KiB
C++

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_strided_slice)
#include <array>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ShapeUtils.h>
#include <helpers/BitwiseUtils.h>
namespace sd {
namespace ops {
constexpr int kShrinkAxis = -1, kNewAxis = -2;
struct StridedSliceSparseSpec {
int dims;
int num_add_axis_after_ellipsis;
std::vector<int>* begin_tensor;
const std::vector<int>* end_tensor;
const std::vector<int>* strides_tensor;
const int begin_mask, end_mask;
int ellipsis_mask;
const int new_axis_mask, shrink_axis_mask;
};
struct StridedSliceDenseSpec {
const int dims;
int begin_mask;
int end_mask;
bool begin_valid;
bool end_valid;
std::vector<int>& begin;
std::vector<int>& end;
std::vector<int>& strides;
std::vector<int> final_shape_gather_indices;
int shrink_axis_mask;
public:
bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) {
if (this->begin.size() < dims)
this->begin.resize(dims);
if (this->end.size() < dims)
this->end.resize(dims);
if (this->strides.size() < dims)
this->strides.resize(dims);
this->begin_mask = 0;
this->end_mask = 0;
this->shrink_axis_mask = 0;
{
int full_index = 0;
this->begin_valid = sparse_spec.begin_tensor != nullptr;
this->end_valid = sparse_spec.end_tensor != nullptr;
for (int e = 0; e < sparse_spec.dims; e++) {
if ((1 << e) & sparse_spec.ellipsis_mask) {
int next_index = sd::math::nd4j_min<int>(this->dims - (sparse_spec.dims - e) + 1 + sparse_spec.num_add_axis_after_ellipsis, this->dims);
for (; full_index < next_index; full_index++) {
// new_axis' aren't real axis so you have to skip
this->begin[full_index] = this->end[full_index] = 0;
this->strides[full_index] = 1;
this->begin_mask |= (1 << full_index);
this->end_mask |= (1 << full_index);
this->final_shape_gather_indices.push_back(full_index);
}
} else if ((1 << e) & sparse_spec.new_axis_mask) {
this->final_shape_gather_indices.emplace_back(kNewAxis);
} else {
if (full_index == this->begin.size()) {
nd4j_printf("Index out of range: %i out of %i\n", full_index, this->dims);
return false;
}
// Gather slicing spec into appropriate index
if (sparse_spec.begin_tensor != nullptr)
this->begin[full_index] = sparse_spec.begin_tensor->at(e);
if (sparse_spec.end_tensor != nullptr)
this->end[full_index] = sparse_spec.end_tensor->at(e);
this->strides[full_index] = sparse_spec.strides_tensor->at(e);
if (sparse_spec.begin_mask & (1 << e))
this->begin_mask |= (1 << full_index);
if (sparse_spec.end_mask & (1 << e))
this->end_mask |= (1 << full_index);
// If shrink, record where to get the dimensionality from (i.e.
// new_axis creates a fake 1 size dimension. Also remember shrink
// axis (now in dense form) so we can ignore dense->end below.
if (sparse_spec.shrink_axis_mask & (1 << e)) {
this->final_shape_gather_indices.push_back(kShrinkAxis);
this->shrink_axis_mask |= (1 << full_index);
} else {
this->final_shape_gather_indices.push_back(full_index);
}
full_index++;
}
}
}
return true;
}
};
void vectorize(std::vector<Nd4jLong>& input_shape) {
if (input_shape.size() == 2 && input_shape[0] == 1) {
int v = input_shape[1];
input_shape.clear();
input_shape.emplace_back(v);
}
}
bool _preprocess_strided_slice(std::vector<Nd4jLong>* indicesList, std::vector<Nd4jLong>* final_shape, std::vector<Nd4jLong>& input_shape, std::vector<int>& begin, std::vector<int>& end, std::vector<int>& strides, int begin_mask, int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, bool* is_identity, bool* is_simple_slice, bool* slice_dim0) {
std::vector<int> preshape;
bool ellipsis_seen = false;
StridedSliceSparseSpec sparse_spec = {(int) strides.size(),
0,
&begin,
&end,
&strides,
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask};
for (int i = 0; i < sparse_spec.dims; i++) {
if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
sparse_spec.num_add_axis_after_ellipsis++;
}
if ((1 << i) & ellipsis_mask) {
ellipsis_seen = true;
}
}
// If no ellipsis insert one at the end
if (!ellipsis_seen) {
sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
sparse_spec.dims++; // this effects loop iteration below
}
StridedSliceDenseSpec dense_spec = {(int) input_shape.size(), 0, 0, false, false, begin, end, strides};
if (!dense_spec.buildDenseSpec(sparse_spec))
return false;
//nd4j_printv("Input shape: ", input_shape);
for (int e = 0; e < (int) input_shape.size(); e++) {
int begin_idx = begin[e];
int end_idx = end[e];
int stride_idx = strides[e];
int size_idx = input_shape[e];
bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e));
if (stride_idx == 0) {
nd4j_printf("Stride is 0 at index %i\n", e);
return false;
}
if (size_idx == -1) {
preshape.emplace_back(shrink_i ? 1 : -1);
continue;
}
const std::array<int, 2> masks = {{dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}};
const std::array<int, 2> valid_range = {{stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}};
auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, int c) {
if (masks[c]) {
return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
} else {
int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive
return x_fwd < valid_range[0] ? valid_range[0] : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
}
};
if (shrink_i && stride_idx <= 0) {
nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", e);
return false;
}
(*is_simple_slice) &= stride_idx == 1;
const bool begin_and_end_masked = (begin_mask & (1 << e)) && (end_mask & (1 << e));
if (dense_spec.begin_valid && dense_spec.end_valid) {
if (shrink_i) {
int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx;
begin_idx = x_fwd;
end_idx = begin_idx + 1;
if (x_fwd < 0 || x_fwd >= size_idx) {
nd4j_printf("slice index %i of dimension %i out of bounds.\n", begin_idx, e);
return false;
}
} else {
begin_idx = canonical(begin_idx, 0);
end_idx = canonical(end_idx, 1);
}
} else {
(*is_identity) &= stride_idx == 1 && begin_and_end_masked;
(*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked;
}
int interval_length = 1;
bool known_interval = false;
if (dense_spec.begin_valid && dense_spec.end_valid) {
interval_length = end_idx - begin_idx;
known_interval = true;
} else if (shrink_i) {
interval_length = 1;
known_interval = true;
} else if (begin_and_end_masked) {
if (size_idx > 0) {
if (stride_idx < 0) {
interval_length = -size_idx;
} else {
interval_length = size_idx;
}
known_interval = true;
}
}
if (known_interval) {
int size_i;
if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) {
size_i = input_shape.size() == 2 && input_shape[0] == 1? 1: 0;
} else {
size_i = interval_length / stride_idx + (interval_length % stride_idx != 0 ? 1 : 0);
}
if (indicesList != nullptr) {
if (interval_length > 1) {
indicesList->push_back(begin_idx);
indicesList->push_back(end_idx);
indicesList->push_back(stride_idx);
// (*indicesList)[3*e] = begin_idx;
// (*indicesList)[3*e+1] = end_idx;
// (*indicesList)[3*e+2] = stride_idx;
}
else if (interval_length == 1) {
indicesList->push_back(begin_idx);
indicesList->push_back(begin_idx + 1);
indicesList->push_back(1);
// (*indicesList)[3*e] = begin_idx;
// (*indicesList)[3*e+1] = begin_idx + 1;
// (*indicesList)[3*e+2] = 1;
}
}
preshape.emplace_back(size_i);
} else {
preshape.emplace_back(-1);
}
}
std::vector<int> postshape;
//nd4j_printv("Preshape: ", preshape);
final_shape->clear();
for (auto gather_index : dense_spec.final_shape_gather_indices) {
if (gather_index >= 0) {
if (preshape.size() > gather_index)
final_shape->emplace_back(preshape.at(gather_index));
else
final_shape->emplace_back(1);
} else if (gather_index == kNewAxis) {
final_shape->emplace_back(1);
}
}
//nd4j_printv("Preshape: ", preshape);
//nd4j_printv("Postshape: ", *final_shape);
return true;
}
CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) {
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (z->isEmpty()) {
return ND4J_STATUS_OK;
}
int begin_mask = INT_ARG(0);
int ellipsis_mask = INT_ARG(1);
int end_mask = INT_ARG(2);
int new_axis_mask = INT_ARG(3);
int shrink_axis_mask = INT_ARG(4);
int dim_values = 0; //block.getIArguments()->size() - 5;
int delta = 0; //dim_values % 3;
int elements = 0; //dim_values / 3;
std::vector<int> begin;
std::vector<int> end;
std::vector<int> strides;
bool isLive = false;
std::vector<int> args;
// statically evaluated
if (block.getIArguments()->size() > 5) {
dim_values = block.getIArguments()->size() - 5;
delta = dim_values % 3;
elements = dim_values / 3;
for (int e = 5; e < block.getIArguments()->size(); e++)
args.emplace_back(INT_ARG(e));
REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values);
ShapeUtils::copyVectorPart(begin, args, elements, 0);
ShapeUtils::copyVectorPart(end, args, elements, elements);
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
} else if (block.width() > 1) {
isLive = true;
auto v_begin = INPUT_VARIABLE(1);
auto v_end = INPUT_VARIABLE(2);
elements = v_begin->lengthOf();
REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSlice: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf());
REQUIRE_TRUE((v_begin->rankOf() == 1 ) && (v_begin->rankOf() == v_end->rankOf()), 0, "StridedSlice: Rank of begin and ends should be 1, but %i given instead", (int)v_end->rankOf());
for (int e = 0; e < v_begin->lengthOf(); e++)
begin.emplace_back(v_begin->e<int>(e));
for (int e = 0; e < v_end->lengthOf(); e++)
end.emplace_back(v_end->e<int>(e));
if (block.width() > 3) {
auto v_stride = INPUT_VARIABLE(3);
REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSlice: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf());
REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, "StridedSlice: Rank of begin and ends should be %i, but %i given instead", (int) v_begin->rankOf(), v_stride->rankOf());
for (int e = 0; e < v_stride->lengthOf(); e++)
strides.emplace_back(v_stride->e<int>(e));
} else {
for (int e = 0; e < v_begin->lengthOf(); e++)
strides.emplace_back(1);
}
} else {
REQUIRE_TRUE(false, 0, "StridedSlice: Can't find begin/end/stride information neither in IArguments or in input arrays");
}
// validation of begin and start
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
if (shrink_axis_mask == 0)
for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) {
if(moveAxes[dim])
continue;
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
}
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
}
++b;
++e;
}
std::vector<Nd4jLong> indices;
auto input_shape = x->getShapeAsVector();
std::vector<Nd4jLong> final_shape;
bool is_identity;
bool is_simple_slice;
bool is_dim0;
// FIXME: remove this method once we get 1D vectors supported
//vectorize(input_shape);
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSlice: shape calculation failed");
// if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == 2 && input_shape[0] == 1)) { //(indices.size() == 6) && (indices[2] - indices[0] == 1)) {
// z->assign(x->e<float>(indices[0]));
// }
// else {
if (indices.size()) {
Nd4jLong* subArrShapeInfo = nullptr;
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong);
Nd4jLong offset;
shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), subArrShapeInfo, offset, true, true);
auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo);
NDArray::prepareSpecialUse({z}, {x});
NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
x->bufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.primary()),
x->specialBufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.special()),
z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(),
nullptr, nullptr, nullptr, true);
NDArray::registerSpecialUse({z}, {x});
RELEASE(subArrShapeInfo, block.getWorkspace());
}
else if (!z->isEmpty()){
z->assign(x->e(0));
}
return Status::OK();
}
DECLARE_SYN(stridedslice, strided_slice);
DECLARE_SHAPE_FN(strided_slice) {
auto inShape = inputShape->at(0);
int begin_mask = INT_ARG(0);
int ellipsis_mask = INT_ARG(1);
int end_mask = INT_ARG(2);
int new_axis_mask = INT_ARG(3);
int shrink_axis_mask = INT_ARG(4);
int x_rank = shape::rank(inShape);
int dim_values = block.getIArguments()->size() - 5;
int delta = dim_values % 3;
int elements = dim_values / 3;
std::vector<int> begin;
std::vector<int> end;
std::vector<int> strides;
// if that's live - shape will be resolved in runtime
if (block.width() > 1) {
begin = INPUT_VARIABLE(1)->template asVectorT<int>();
end = INPUT_VARIABLE(2)->template asVectorT<int>();
strides = INPUT_VARIABLE(3)->template asVectorT<int>();
} else if (dim_values > 0) {
int delta2 = dim_values / x_rank;
std::vector<int> args;
for (int e = 5; e < block.getIArguments()->size(); e++)
args.emplace_back(INT_ARG(e));
// FIXME: propably template required here
ShapeUtils::copyVectorPart(begin, args, elements, 0);
ShapeUtils::copyVectorPart(end, args, elements, elements);
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
}
REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, "Strided_Slice: empty arguments");
// validation of begin and start
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
//if (0 == shrink_axis_mask)
if (false)
for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) {
if(moveAxes[dim])
continue;
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
REQUIRE_TRUE(first <= inShape[dim + 1], 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
}
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
REQUIRE_TRUE(last <= inShape[dim + 1], 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
}
++b;
++e;
}
std::vector<Nd4jLong> input_shape; //(shape::rank(inShape));
auto inputLen = shape::length(inShape);
std::vector<Nd4jLong> shape;
auto rank = shape::rank(inShape);
auto shortShape = shape::shapeOf(inShape);
for (auto e = 0; e < rank; e++)
input_shape.emplace_back(shortShape[e]);
bool is_identity;
bool is_simple_slice;
bool is_dim0;
std::vector<Nd4jLong> indices;
bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
if (indices.size()) {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
shape);
// if (inputLen > 1) {
// newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
// shape);
// } else {
// newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
// }
return SHAPELIST(newShape);
}
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape)));
}
CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) {
auto x = INPUT_VARIABLE(0);
auto epsNext = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
int begin_mask = INT_ARG(0);
int ellipsis_mask = INT_ARG(1);
int end_mask = INT_ARG(2);
int new_axis_mask = INT_ARG(3);
int shrink_axis_mask = INT_ARG(4);
int dim_values = 0; //block.getIArguments()->size() - 5;
int delta = 0; //dim_values % 3;
int elements = 0; //dim_values / 3;
std::vector<int> begin;
std::vector<int> end;
std::vector<int> strides;
bool isLive = false;
std::vector<int> args;
// statically evaluated
if (block.getIArguments()->size() > 5) {
dim_values = block.getIArguments()->size() - 5;
delta = dim_values % 3;
elements = dim_values / 3;
for (int e = 5; e < block.getIArguments()->size(); e++)
args.emplace_back(INT_ARG(e));
REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values);
ShapeUtils::copyVectorPart(begin, args, elements, 0);
ShapeUtils::copyVectorPart(end, args, elements, elements);
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
} else if (block.width() >= 3) {
isLive = true;
auto v_begin = INPUT_VARIABLE(2);
auto v_end = INPUT_VARIABLE(3);
elements = v_begin->lengthOf();
REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSliceBP: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf());
for (int e = 0; e < v_begin->lengthOf(); e++)
begin.emplace_back(v_begin->e<int>(e));
for (int e = 0; e < v_end->lengthOf(); e++)
end.emplace_back(v_end->e<int>(e));
if (block.width() >= 4) {
auto v_stride = INPUT_VARIABLE(4);
REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSliceBP: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf());
for (int e = 0; e < v_stride->lengthOf(); e++)
strides.emplace_back(v_stride->e<int>(e));
} else {
for (int e = 0; e < v_begin->lengthOf(); e++)
strides.emplace_back(1);
}
} else {
REQUIRE_TRUE(false, 0, "StridedSliceBP: Can't find begin/end/stride information neither in IArguments or in input arrays");
}
// validation of begin and start
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) {
if(moveAxes[dim])
continue;
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
}
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
}
++b;
++e;
}
auto input_shape = x->getShapeAsVector();
std::vector<Nd4jLong> indices;
std::vector<Nd4jLong> final_shape;
bool is_identity;
bool is_simple_slice;
bool is_dim0;
// FIXME: remove this method once we get 1D vectors supported
vectorize(input_shape);
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed");
//REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op.");
//Zero output array, so unused elements have 0 gradient
output->nullify();
//
// the first case: only for scalar gradient step
if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) {
output->p(indices[0], *epsNext);
}
else { // else for other cases
auto sub = (*output)(indices, true, true);
sub.assign(epsNext);
}
return Status::OK();
}
DECLARE_SHAPE_FN(strided_slice_bp) {
auto inShape = inputShape->at(0);
Nd4jLong *newShape;
COPY_SHAPE(inShape, newShape);
return SHAPELIST(CONSTANT(newShape));
}
DECLARE_TYPES(strided_slice) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setSameMode(true);
}
DECLARE_TYPES(strided_slice_bp) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif