679 lines
31 KiB
C++
679 lines
31 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <system/op_boilerplate.h>
|
|
#if NOT_EXCLUDED(OP_strided_slice)
|
|
|
|
#include <array>
|
|
#include <ops/declarable/CustomOperations.h>
|
|
#include <helpers/ShapeUtils.h>
|
|
#include <helpers/BitwiseUtils.h>
|
|
|
|
namespace sd {
|
|
namespace ops {
|
|
|
|
constexpr int kShrinkAxis = -1, kNewAxis = -2;
|
|
|
|
struct StridedSliceSparseSpec {
|
|
int dims;
|
|
int num_add_axis_after_ellipsis;
|
|
std::vector<int>* begin_tensor;
|
|
const std::vector<int>* end_tensor;
|
|
const std::vector<int>* strides_tensor;
|
|
const int begin_mask, end_mask;
|
|
int ellipsis_mask;
|
|
const int new_axis_mask, shrink_axis_mask;
|
|
};
|
|
|
|
struct StridedSliceDenseSpec {
|
|
const int dims;
|
|
int begin_mask;
|
|
int end_mask;
|
|
bool begin_valid;
|
|
bool end_valid;
|
|
std::vector<int>& begin;
|
|
std::vector<int>& end;
|
|
std::vector<int>& strides;
|
|
std::vector<int> final_shape_gather_indices;
|
|
int shrink_axis_mask;
|
|
|
|
public:
|
|
bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) {
|
|
if (this->begin.size() < dims)
|
|
this->begin.resize(dims);
|
|
|
|
if (this->end.size() < dims)
|
|
this->end.resize(dims);
|
|
|
|
if (this->strides.size() < dims)
|
|
this->strides.resize(dims);
|
|
this->begin_mask = 0;
|
|
this->end_mask = 0;
|
|
this->shrink_axis_mask = 0;
|
|
{
|
|
int full_index = 0;
|
|
|
|
this->begin_valid = sparse_spec.begin_tensor != nullptr;
|
|
this->end_valid = sparse_spec.end_tensor != nullptr;
|
|
|
|
for (int e = 0; e < sparse_spec.dims; e++) {
|
|
if ((1 << e) & sparse_spec.ellipsis_mask) {
|
|
int next_index = sd::math::nd4j_min<int>(this->dims - (sparse_spec.dims - e) + 1 + sparse_spec.num_add_axis_after_ellipsis, this->dims);
|
|
|
|
for (; full_index < next_index; full_index++) {
|
|
// new_axis' aren't real axis so you have to skip
|
|
this->begin[full_index] = this->end[full_index] = 0;
|
|
this->strides[full_index] = 1;
|
|
this->begin_mask |= (1 << full_index);
|
|
this->end_mask |= (1 << full_index);
|
|
this->final_shape_gather_indices.push_back(full_index);
|
|
}
|
|
} else if ((1 << e) & sparse_spec.new_axis_mask) {
|
|
this->final_shape_gather_indices.emplace_back(kNewAxis);
|
|
} else {
|
|
if (full_index == this->begin.size()) {
|
|
nd4j_printf("Index out of range: %i out of %i\n", full_index, this->dims);
|
|
return false;
|
|
}
|
|
|
|
// Gather slicing spec into appropriate index
|
|
if (sparse_spec.begin_tensor != nullptr)
|
|
this->begin[full_index] = sparse_spec.begin_tensor->at(e);
|
|
|
|
|
|
if (sparse_spec.end_tensor != nullptr)
|
|
this->end[full_index] = sparse_spec.end_tensor->at(e);
|
|
|
|
this->strides[full_index] = sparse_spec.strides_tensor->at(e);
|
|
|
|
if (sparse_spec.begin_mask & (1 << e))
|
|
this->begin_mask |= (1 << full_index);
|
|
|
|
|
|
if (sparse_spec.end_mask & (1 << e))
|
|
this->end_mask |= (1 << full_index);
|
|
|
|
|
|
// If shrink, record where to get the dimensionality from (i.e.
|
|
// new_axis creates a fake 1 size dimension. Also remember shrink
|
|
// axis (now in dense form) so we can ignore dense->end below.
|
|
if (sparse_spec.shrink_axis_mask & (1 << e)) {
|
|
this->final_shape_gather_indices.push_back(kShrinkAxis);
|
|
this->shrink_axis_mask |= (1 << full_index);
|
|
} else {
|
|
this->final_shape_gather_indices.push_back(full_index);
|
|
}
|
|
full_index++;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
|
|
void vectorize(std::vector<Nd4jLong>& input_shape) {
|
|
if (input_shape.size() == 2 && input_shape[0] == 1) {
|
|
int v = input_shape[1];
|
|
input_shape.clear();
|
|
input_shape.emplace_back(v);
|
|
}
|
|
}
|
|
|
|
bool _preprocess_strided_slice(std::vector<Nd4jLong>* indicesList, std::vector<Nd4jLong>* final_shape, std::vector<Nd4jLong>& input_shape, std::vector<int>& begin, std::vector<int>& end, std::vector<int>& strides, int begin_mask, int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, bool* is_identity, bool* is_simple_slice, bool* slice_dim0) {
|
|
std::vector<int> preshape;
|
|
|
|
bool ellipsis_seen = false;
|
|
|
|
StridedSliceSparseSpec sparse_spec = {(int) strides.size(),
|
|
0,
|
|
&begin,
|
|
&end,
|
|
&strides,
|
|
begin_mask,
|
|
end_mask,
|
|
ellipsis_mask,
|
|
new_axis_mask,
|
|
shrink_axis_mask};
|
|
|
|
for (int i = 0; i < sparse_spec.dims; i++) {
|
|
if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
|
|
sparse_spec.num_add_axis_after_ellipsis++;
|
|
}
|
|
if ((1 << i) & ellipsis_mask) {
|
|
ellipsis_seen = true;
|
|
}
|
|
}
|
|
// If no ellipsis insert one at the end
|
|
if (!ellipsis_seen) {
|
|
sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
|
|
sparse_spec.dims++; // this effects loop iteration below
|
|
}
|
|
|
|
StridedSliceDenseSpec dense_spec = {(int) input_shape.size(), 0, 0, false, false, begin, end, strides};
|
|
if (!dense_spec.buildDenseSpec(sparse_spec))
|
|
return false;
|
|
|
|
//nd4j_printv("Input shape: ", input_shape);
|
|
|
|
for (int e = 0; e < (int) input_shape.size(); e++) {
|
|
int begin_idx = begin[e];
|
|
int end_idx = end[e];
|
|
int stride_idx = strides[e];
|
|
int size_idx = input_shape[e];
|
|
|
|
bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e));
|
|
|
|
if (stride_idx == 0) {
|
|
nd4j_printf("Stride is 0 at index %i\n", e);
|
|
return false;
|
|
}
|
|
if (size_idx == -1) {
|
|
preshape.emplace_back(shrink_i ? 1 : -1);
|
|
continue;
|
|
}
|
|
|
|
const std::array<int, 2> masks = {{dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}};
|
|
const std::array<int, 2> valid_range = {{stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}};
|
|
|
|
auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, int c) {
|
|
if (masks[c]) {
|
|
return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
|
|
} else {
|
|
int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive
|
|
return x_fwd < valid_range[0] ? valid_range[0] : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
|
|
}
|
|
};
|
|
|
|
if (shrink_i && stride_idx <= 0) {
|
|
nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", e);
|
|
return false;
|
|
}
|
|
|
|
(*is_simple_slice) &= stride_idx == 1;
|
|
|
|
const bool begin_and_end_masked = (begin_mask & (1 << e)) && (end_mask & (1 << e));
|
|
|
|
if (dense_spec.begin_valid && dense_spec.end_valid) {
|
|
if (shrink_i) {
|
|
int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx;
|
|
begin_idx = x_fwd;
|
|
end_idx = begin_idx + 1;
|
|
if (x_fwd < 0 || x_fwd >= size_idx) {
|
|
nd4j_printf("slice index %i of dimension %i out of bounds.\n", begin_idx, e);
|
|
return false;
|
|
}
|
|
} else {
|
|
begin_idx = canonical(begin_idx, 0);
|
|
end_idx = canonical(end_idx, 1);
|
|
}
|
|
} else {
|
|
(*is_identity) &= stride_idx == 1 && begin_and_end_masked;
|
|
(*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked;
|
|
}
|
|
|
|
int interval_length = 1;
|
|
bool known_interval = false;
|
|
if (dense_spec.begin_valid && dense_spec.end_valid) {
|
|
interval_length = end_idx - begin_idx;
|
|
known_interval = true;
|
|
} else if (shrink_i) {
|
|
interval_length = 1;
|
|
known_interval = true;
|
|
} else if (begin_and_end_masked) {
|
|
if (size_idx > 0) {
|
|
if (stride_idx < 0) {
|
|
interval_length = -size_idx;
|
|
} else {
|
|
interval_length = size_idx;
|
|
}
|
|
|
|
known_interval = true;
|
|
}
|
|
}
|
|
|
|
if (known_interval) {
|
|
int size_i;
|
|
if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) {
|
|
size_i = input_shape.size() == 2 && input_shape[0] == 1? 1: 0;
|
|
} else {
|
|
size_i = interval_length / stride_idx + (interval_length % stride_idx != 0 ? 1 : 0);
|
|
}
|
|
|
|
if (indicesList != nullptr) {
|
|
if (interval_length > 1) {
|
|
indicesList->push_back(begin_idx);
|
|
indicesList->push_back(end_idx);
|
|
indicesList->push_back(stride_idx);
|
|
// (*indicesList)[3*e] = begin_idx;
|
|
// (*indicesList)[3*e+1] = end_idx;
|
|
// (*indicesList)[3*e+2] = stride_idx;
|
|
}
|
|
else if (interval_length == 1) {
|
|
indicesList->push_back(begin_idx);
|
|
indicesList->push_back(begin_idx + 1);
|
|
indicesList->push_back(1);
|
|
// (*indicesList)[3*e] = begin_idx;
|
|
// (*indicesList)[3*e+1] = begin_idx + 1;
|
|
// (*indicesList)[3*e+2] = 1;
|
|
}
|
|
}
|
|
|
|
preshape.emplace_back(size_i);
|
|
} else {
|
|
preshape.emplace_back(-1);
|
|
}
|
|
}
|
|
|
|
|
|
std::vector<int> postshape;
|
|
//nd4j_printv("Preshape: ", preshape);
|
|
|
|
final_shape->clear();
|
|
for (auto gather_index : dense_spec.final_shape_gather_indices) {
|
|
if (gather_index >= 0) {
|
|
if (preshape.size() > gather_index)
|
|
final_shape->emplace_back(preshape.at(gather_index));
|
|
else
|
|
final_shape->emplace_back(1);
|
|
} else if (gather_index == kNewAxis) {
|
|
final_shape->emplace_back(1);
|
|
}
|
|
}
|
|
|
|
//nd4j_printv("Preshape: ", preshape);
|
|
//nd4j_printv("Postshape: ", *final_shape);
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) {
|
|
auto x = INPUT_VARIABLE(0);
|
|
auto z = OUTPUT_VARIABLE(0);
|
|
if (z->isEmpty() || z->lengthOf() == 0) {
|
|
return ND4J_STATUS_OK;
|
|
}
|
|
|
|
int begin_mask = INT_ARG(0);
|
|
int ellipsis_mask = INT_ARG(1);
|
|
int end_mask = INT_ARG(2);
|
|
int new_axis_mask = INT_ARG(3);
|
|
int shrink_axis_mask = INT_ARG(4);
|
|
|
|
int dim_values = 0; //block.getIArguments()->size() - 5;
|
|
int delta = 0; //dim_values % 3;
|
|
int elements = 0; //dim_values / 3;
|
|
|
|
std::vector<int> begin;
|
|
std::vector<int> end;
|
|
std::vector<int> strides;
|
|
|
|
bool isLive = false;
|
|
|
|
std::vector<int> args;
|
|
|
|
// statically evaluated
|
|
if (block.getIArguments()->size() > 5) {
|
|
dim_values = block.getIArguments()->size() - 5;
|
|
delta = dim_values % 3;
|
|
elements = dim_values / 3;
|
|
|
|
for (int e = 5; e < block.getIArguments()->size(); e++)
|
|
args.emplace_back(INT_ARG(e));
|
|
|
|
REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values);
|
|
|
|
ShapeUtils::copyVectorPart(begin, args, elements, 0);
|
|
ShapeUtils::copyVectorPart(end, args, elements, elements);
|
|
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
|
|
|
|
} else if (block.width() > 1) {
|
|
isLive = true;
|
|
|
|
auto v_begin = INPUT_VARIABLE(1);
|
|
auto v_end = INPUT_VARIABLE(2);
|
|
|
|
elements = v_begin->lengthOf();
|
|
|
|
REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSlice: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf());
|
|
REQUIRE_TRUE((v_begin->rankOf() == 1 ) && (v_begin->rankOf() == v_end->rankOf()), 0, "StridedSlice: Rank of begin and ends should be 1, but %i given instead", (int)v_end->rankOf());
|
|
|
|
for (int e = 0; e < v_begin->lengthOf(); e++)
|
|
begin.emplace_back(v_begin->e<int>(e));
|
|
|
|
for (int e = 0; e < v_end->lengthOf(); e++)
|
|
end.emplace_back(v_end->e<int>(e));
|
|
|
|
if (block.width() > 3) {
|
|
auto v_stride = INPUT_VARIABLE(3);
|
|
|
|
REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSlice: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf());
|
|
REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, "StridedSlice: Rank of begin and ends should be %i, but %i given instead", (int) v_begin->rankOf(), v_stride->rankOf());
|
|
|
|
for (int e = 0; e < v_stride->lengthOf(); e++)
|
|
strides.emplace_back(v_stride->e<int>(e));
|
|
} else {
|
|
for (int e = 0; e < v_begin->lengthOf(); e++)
|
|
strides.emplace_back(1);
|
|
}
|
|
} else {
|
|
REQUIRE_TRUE(false, 0, "StridedSlice: Can't find begin/end/stride information neither in IArguments or in input arrays");
|
|
}
|
|
|
|
// validation of begin and start
|
|
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
|
|
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
|
|
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
|
|
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
|
|
if (shrink_axis_mask == 0)
|
|
for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) {
|
|
|
|
if(moveAxes[dim])
|
|
continue;
|
|
|
|
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
|
|
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
|
|
REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
|
|
}
|
|
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
|
|
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
|
|
REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
|
|
}
|
|
++b;
|
|
++e;
|
|
}
|
|
|
|
|
|
std::vector<Nd4jLong> indices;
|
|
auto input_shape = x->getShapeAsVector();
|
|
std::vector<Nd4jLong> final_shape;
|
|
bool is_identity;
|
|
bool is_simple_slice;
|
|
bool is_dim0;
|
|
|
|
// FIXME: remove this method once we get 1D vectors supported
|
|
//vectorize(input_shape);
|
|
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSlice: shape calculation failed");
|
|
// if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == 2 && input_shape[0] == 1)) { //(indices.size() == 6) && (indices[2] - indices[0] == 1)) {
|
|
// z->assign(x->e<float>(indices[0]));
|
|
// }
|
|
// else {
|
|
if (indices.size()) {
|
|
Nd4jLong* subArrShapeInfo = nullptr;
|
|
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong);
|
|
Nd4jLong offset;
|
|
|
|
shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), subArrShapeInfo, offset, true, true);
|
|
auto subArrShapeInfoPack = ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo);
|
|
|
|
NDArray::prepareSpecialUse({z}, {x});
|
|
|
|
NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
|
|
x->bufferWithOffset(offset), subArrShapeInfoPack.primary(),
|
|
x->specialBufferWithOffset(offset), subArrShapeInfoPack.special(),
|
|
z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(),
|
|
nullptr, nullptr, nullptr, true);
|
|
|
|
NDArray::registerSpecialUse({z}, {x});
|
|
|
|
RELEASE(subArrShapeInfo, block.getWorkspace());
|
|
}
|
|
else if (!z->isEmpty()) {
|
|
z->assign(x->e(0));
|
|
}
|
|
return Status::OK();
|
|
}
|
|
DECLARE_SYN(stridedslice, strided_slice);
|
|
|
|
DECLARE_SHAPE_FN(strided_slice) {
|
|
auto inShape = inputShape->at(0);
|
|
|
|
int begin_mask = INT_ARG(0);
|
|
int ellipsis_mask = INT_ARG(1);
|
|
int end_mask = INT_ARG(2);
|
|
int new_axis_mask = INT_ARG(3);
|
|
int shrink_axis_mask = INT_ARG(4);
|
|
|
|
int x_rank = shape::rank(inShape);
|
|
|
|
int dim_values = block.getIArguments()->size() - 5;
|
|
int delta = dim_values % 3;
|
|
int elements = dim_values / 3;
|
|
|
|
|
|
std::vector<int> begin;
|
|
std::vector<int> end;
|
|
std::vector<int> strides;
|
|
|
|
// if that's live - shape will be resolved in runtime
|
|
if (block.width() > 1) {
|
|
begin = INPUT_VARIABLE(1)->template asVectorT<int>();
|
|
end = INPUT_VARIABLE(2)->template asVectorT<int>();
|
|
strides = INPUT_VARIABLE(3)->template asVectorT<int>();
|
|
} else if (dim_values > 0) {
|
|
int delta2 = dim_values / x_rank;
|
|
|
|
std::vector<int> args;
|
|
for (int e = 5; e < block.getIArguments()->size(); e++)
|
|
args.emplace_back(INT_ARG(e));
|
|
|
|
// FIXME: probably template required here
|
|
ShapeUtils::copyVectorPart(begin, args, elements, 0);
|
|
ShapeUtils::copyVectorPart(end, args, elements, elements);
|
|
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
|
|
}
|
|
|
|
REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, "Strided_Slice: empty arguments");
|
|
|
|
// validation of begin and start
|
|
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
|
|
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
|
|
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
|
|
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
|
|
|
|
//if (0 == shrink_axis_mask)
|
|
if (false)
|
|
for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) {
|
|
|
|
if(moveAxes[dim])
|
|
continue;
|
|
|
|
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
|
|
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
|
|
REQUIRE_TRUE(first <= inShape[dim + 1], 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
|
|
}
|
|
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
|
|
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
|
|
REQUIRE_TRUE(last <= inShape[dim + 1], 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
|
|
}
|
|
++b;
|
|
++e;
|
|
}
|
|
|
|
std::vector<Nd4jLong> input_shape; //(shape::rank(inShape));
|
|
auto inputLen = shape::length(inShape);
|
|
std::vector<Nd4jLong> shape;
|
|
|
|
auto rank = shape::rank(inShape);
|
|
auto shortShape = shape::shapeOf(inShape);
|
|
for (auto e = 0; e < rank; e++)
|
|
input_shape.emplace_back(shortShape[e]);
|
|
|
|
bool is_identity;
|
|
bool is_simple_slice;
|
|
bool is_dim0;
|
|
|
|
std::vector<Nd4jLong> indices;
|
|
bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
|
|
if (indices.size()) {
|
|
auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
|
shape);
|
|
// if (inputLen > 1) {
|
|
// newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
|
// shape);
|
|
// } else {
|
|
// newShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape));
|
|
// }
|
|
nd4j_printf("Returning new shape %d\n",0);
|
|
return SHAPELIST(newShape);
|
|
}
|
|
|
|
nd4j_printf("Returning empty shape info %d\n",0);
|
|
return SHAPELIST(ConstantShapeHelper::getInstance().emptyShapeInfo(ArrayOptions::dataType(inShape)));
|
|
}
|
|
|
|
|
|
CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) {
|
|
auto x = INPUT_VARIABLE(0);
|
|
auto epsNext = INPUT_VARIABLE(1);
|
|
auto output = OUTPUT_VARIABLE(0);
|
|
|
|
int begin_mask = INT_ARG(0);
|
|
int ellipsis_mask = INT_ARG(1);
|
|
int end_mask = INT_ARG(2);
|
|
int new_axis_mask = INT_ARG(3);
|
|
int shrink_axis_mask = INT_ARG(4);
|
|
|
|
int dim_values = 0; //block.getIArguments()->size() - 5;
|
|
int delta = 0; //dim_values % 3;
|
|
int elements = 0; //dim_values / 3;
|
|
|
|
std::vector<int> begin;
|
|
std::vector<int> end;
|
|
std::vector<int> strides;
|
|
|
|
bool isLive = false;
|
|
|
|
std::vector<int> args;
|
|
|
|
// statically evaluated
|
|
if (block.getIArguments()->size() > 5) {
|
|
dim_values = block.getIArguments()->size() - 5;
|
|
delta = dim_values % 3;
|
|
elements = dim_values / 3;
|
|
|
|
for (int e = 5; e < block.getIArguments()->size(); e++)
|
|
args.emplace_back(INT_ARG(e));
|
|
|
|
REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values);
|
|
|
|
ShapeUtils::copyVectorPart(begin, args, elements, 0);
|
|
ShapeUtils::copyVectorPart(end, args, elements, elements);
|
|
ShapeUtils::copyVectorPart(strides, args, elements, elements * 2);
|
|
|
|
} else if (block.width() >= 3) {
|
|
isLive = true;
|
|
|
|
auto v_begin = INPUT_VARIABLE(2);
|
|
auto v_end = INPUT_VARIABLE(3);
|
|
|
|
elements = v_begin->lengthOf();
|
|
|
|
REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSliceBP: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf());
|
|
|
|
for (int e = 0; e < v_begin->lengthOf(); e++)
|
|
begin.emplace_back(v_begin->e<int>(e));
|
|
|
|
for (int e = 0; e < v_end->lengthOf(); e++)
|
|
end.emplace_back(v_end->e<int>(e));
|
|
|
|
if (block.width() >= 4) {
|
|
auto v_stride = INPUT_VARIABLE(4);
|
|
|
|
REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSliceBP: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf());
|
|
|
|
for (int e = 0; e < v_stride->lengthOf(); e++)
|
|
strides.emplace_back(v_stride->e<int>(e));
|
|
} else {
|
|
for (int e = 0; e < v_begin->lengthOf(); e++)
|
|
strides.emplace_back(1);
|
|
}
|
|
} else {
|
|
REQUIRE_TRUE(false, 0, "StridedSliceBP: Can't find begin/end/stride information neither in IArguments or in input arrays");
|
|
}
|
|
|
|
// validation of begin and start
|
|
std::vector<int> ignoreBegin = BitwiseUtils::valueBits(begin_mask);
|
|
std::vector<int> ignoreEnd = BitwiseUtils::valueBits(end_mask);
|
|
std::vector<int> addAxes = BitwiseUtils::valueBits(new_axis_mask);
|
|
std::vector<int> moveAxes = BitwiseUtils::valueBits(shrink_axis_mask);
|
|
|
|
for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) {
|
|
|
|
if(moveAxes[dim])
|
|
continue;
|
|
|
|
if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) {
|
|
int first = strides[b] > 0 ? begin[b] : math::nd4j_abs<int>(begin[b]) - 1;
|
|
REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim);
|
|
}
|
|
if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) {
|
|
int last = strides[e] > 0 ? end[e] : math::nd4j_abs<int>(end[e]) - 1;
|
|
REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim);
|
|
}
|
|
++b;
|
|
++e;
|
|
}
|
|
|
|
auto input_shape = x->getShapeAsVector();
|
|
std::vector<Nd4jLong> indices;
|
|
std::vector<Nd4jLong> final_shape;
|
|
bool is_identity;
|
|
bool is_simple_slice;
|
|
bool is_dim0;
|
|
|
|
// FIXME: remove this method once we get 1D vectors supported
|
|
vectorize(input_shape);
|
|
REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed");
|
|
//REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op.");
|
|
//Zero output array, so unused elements have 0 gradient
|
|
output->nullify();
|
|
//
|
|
// the first case: only for scalar gradient step
|
|
if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) {
|
|
output->p(indices[0], *epsNext);
|
|
}
|
|
else { // else for other cases
|
|
auto sub = (*output)(indices, true, true);
|
|
sub.assign(epsNext);
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
DECLARE_SHAPE_FN(strided_slice_bp) {
|
|
auto inShape = inputShape->at(0);
|
|
Nd4jLong *newShape;
|
|
COPY_SHAPE(inShape, newShape);
|
|
|
|
return SHAPELIST(CONSTANT(newShape));
|
|
}
|
|
|
|
DECLARE_TYPES(strided_slice) {
|
|
getOpDescriptor()
|
|
->setAllowedInputTypes(sd::DataType::ANY)
|
|
->setSameMode(true);
|
|
}
|
|
|
|
DECLARE_TYPES(strided_slice_bp) {
|
|
getOpDescriptor()
|
|
->setAllowedInputTypes(sd::DataType::ANY)
|
|
->setAllowedOutputTypes({ALL_FLOATS});
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif
|