/* * ****************************************************************************** * * * * * * This program and the accompanying materials are made available under the * * terms of the Apache License, Version 2.0 which is available at * * https://www.apache.org/licenses/LICENSE-2.0. * * * * See the NOTICE file distributed with this work for additional * * information regarding copyright ownership. * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * * License for the specific language governing permissions and limitations * * under the License. * * * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ #include #if NOT_EXCLUDED(OP_strided_slice) #include #include #include #include namespace sd { namespace ops { constexpr int kShrinkAxis = -1, kNewAxis = -2; struct StridedSliceSparseSpec { int dims; int num_add_axis_after_ellipsis; std::vector* begin_tensor; const std::vector* end_tensor; const std::vector* strides_tensor; const int begin_mask, end_mask; int ellipsis_mask; const int new_axis_mask, shrink_axis_mask; }; struct StridedSliceDenseSpec { const int dims; int begin_mask; int end_mask; bool begin_valid; bool end_valid; std::vector& begin; std::vector& end; std::vector& strides; std::vector final_shape_gather_indices; int shrink_axis_mask; public: bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) { if (this->begin.size() < dims) this->begin.resize(dims); if (this->end.size() < dims) this->end.resize(dims); if (this->strides.size() < dims) this->strides.resize(dims); this->begin_mask = 0; this->end_mask = 0; this->shrink_axis_mask = 0; { int full_index = 0; this->begin_valid = sparse_spec.begin_tensor != nullptr; this->end_valid = sparse_spec.end_tensor != nullptr; for (int e = 0; e < sparse_spec.dims; e++) { if ((1 << e) & sparse_spec.ellipsis_mask) { int next_index = sd::math::nd4j_min(this->dims - (sparse_spec.dims - e) + 1 + sparse_spec.num_add_axis_after_ellipsis, this->dims); for (; full_index < next_index; full_index++) { // new_axis' aren't real axis so you have to skip this->begin[full_index] = this->end[full_index] = 0; this->strides[full_index] = 1; this->begin_mask |= (1 << full_index); this->end_mask |= (1 << full_index); this->final_shape_gather_indices.push_back(full_index); } } else if ((1 << e) & sparse_spec.new_axis_mask) { this->final_shape_gather_indices.emplace_back(kNewAxis); } else { if (full_index == this->begin.size()) { nd4j_printf("Index out of range: %i out of %i\n", full_index, this->dims); return false; } // Gather slicing spec into appropriate index if (sparse_spec.begin_tensor != nullptr) this->begin[full_index] = sparse_spec.begin_tensor->at(e); if (sparse_spec.end_tensor != nullptr) this->end[full_index] = sparse_spec.end_tensor->at(e); this->strides[full_index] = sparse_spec.strides_tensor->at(e); if (sparse_spec.begin_mask & (1 << e)) this->begin_mask |= (1 << full_index); if (sparse_spec.end_mask & (1 << e)) this->end_mask |= (1 << full_index); // If shrink, record where to get the dimensionality from (i.e. // new_axis creates a fake 1 size dimension. Also remember shrink // axis (now in dense form) so we can ignore dense->end below. if (sparse_spec.shrink_axis_mask & (1 << e)) { this->final_shape_gather_indices.push_back(kShrinkAxis); this->shrink_axis_mask |= (1 << full_index); } else { this->final_shape_gather_indices.push_back(full_index); } full_index++; } } } return true; } }; void vectorize(std::vector& input_shape) { if (input_shape.size() == 2 && input_shape[0] == 1) { int v = input_shape[1]; input_shape.clear(); input_shape.emplace_back(v); } } bool _preprocess_strided_slice(std::vector* indicesList, std::vector* final_shape, std::vector& input_shape, std::vector& begin, std::vector& end, std::vector& strides, int begin_mask, int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, bool* is_identity, bool* is_simple_slice, bool* slice_dim0) { std::vector preshape; bool ellipsis_seen = false; StridedSliceSparseSpec sparse_spec = {(int) strides.size(), 0, &begin, &end, &strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask}; for (int i = 0; i < sparse_spec.dims; i++) { if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { sparse_spec.num_add_axis_after_ellipsis++; } if ((1 << i) & ellipsis_mask) { ellipsis_seen = true; } } // If no ellipsis insert one at the end if (!ellipsis_seen) { sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); sparse_spec.dims++; // this effects loop iteration below } StridedSliceDenseSpec dense_spec = {(int) input_shape.size(), 0, 0, false, false, begin, end, strides}; if (!dense_spec.buildDenseSpec(sparse_spec)) return false; //nd4j_printv("Input shape: ", input_shape); for (int e = 0; e < (int) input_shape.size(); e++) { int begin_idx = begin[e]; int end_idx = end[e]; int stride_idx = strides[e]; int size_idx = input_shape[e]; bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e)); if (stride_idx == 0) { nd4j_printf("Stride is 0 at index %i\n", e); return false; } if (size_idx == -1) { preshape.emplace_back(shrink_i ? 1 : -1); continue; } const std::array masks = {{dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}}; const std::array valid_range = {{stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}}; auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, int c) { if (masks[c]) { return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; } else { int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive return x_fwd < valid_range[0] ? valid_range[0] : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; } }; if (shrink_i && stride_idx <= 0) { nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", e); return false; } (*is_simple_slice) &= stride_idx == 1; const bool begin_and_end_masked = (begin_mask & (1 << e)) && (end_mask & (1 << e)); if (dense_spec.begin_valid && dense_spec.end_valid) { if (shrink_i) { int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx; begin_idx = x_fwd; end_idx = begin_idx + 1; if (x_fwd < 0 || x_fwd >= size_idx) { nd4j_printf("slice index %i of dimension %i out of bounds.\n", begin_idx, e); return false; } } else { begin_idx = canonical(begin_idx, 0); end_idx = canonical(end_idx, 1); } } else { (*is_identity) &= stride_idx == 1 && begin_and_end_masked; (*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked; } int interval_length = 1; bool known_interval = false; if (dense_spec.begin_valid && dense_spec.end_valid) { interval_length = end_idx - begin_idx; known_interval = true; } else if (shrink_i) { interval_length = 1; known_interval = true; } else if (begin_and_end_masked) { if (size_idx > 0) { if (stride_idx < 0) { interval_length = -size_idx; } else { interval_length = size_idx; } known_interval = true; } } if (known_interval) { int size_i; if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) { size_i = input_shape.size() == 2 && input_shape[0] == 1? 1: 0; } else { size_i = interval_length / stride_idx + (interval_length % stride_idx != 0 ? 1 : 0); } if (indicesList != nullptr) { if (interval_length > 1) { indicesList->push_back(begin_idx); indicesList->push_back(end_idx); indicesList->push_back(stride_idx); // (*indicesList)[3*e] = begin_idx; // (*indicesList)[3*e+1] = end_idx; // (*indicesList)[3*e+2] = stride_idx; } else if (interval_length == 1) { indicesList->push_back(begin_idx); indicesList->push_back(begin_idx + 1); indicesList->push_back(1); // (*indicesList)[3*e] = begin_idx; // (*indicesList)[3*e+1] = begin_idx + 1; // (*indicesList)[3*e+2] = 1; } } preshape.emplace_back(size_i); } else { preshape.emplace_back(-1); } } std::vector postshape; //nd4j_printv("Preshape: ", preshape); final_shape->clear(); for (auto gather_index : dense_spec.final_shape_gather_indices) { if (gather_index >= 0) { if (preshape.size() > gather_index) final_shape->emplace_back(preshape.at(gather_index)); else final_shape->emplace_back(1); } else if (gather_index == kNewAxis) { final_shape->emplace_back(1); } } //nd4j_printv("Preshape: ", preshape); //nd4j_printv("Postshape: ", *final_shape); return true; } CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); if (z->isEmpty() || z->lengthOf() == 0) { return ND4J_STATUS_OK; } int begin_mask = INT_ARG(0); int ellipsis_mask = INT_ARG(1); int end_mask = INT_ARG(2); int new_axis_mask = INT_ARG(3); int shrink_axis_mask = INT_ARG(4); int dim_values = 0; //block.getIArguments()->size() - 5; int delta = 0; //dim_values % 3; int elements = 0; //dim_values / 3; std::vector begin; std::vector end; std::vector strides; bool isLive = false; std::vector args; // statically evaluated if (block.getIArguments()->size() > 5) { dim_values = block.getIArguments()->size() - 5; delta = dim_values % 3; elements = dim_values / 3; for (int e = 5; e < block.getIArguments()->size(); e++) args.emplace_back(INT_ARG(e)); REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); ShapeUtils::copyVectorPart(begin, args, elements, 0); ShapeUtils::copyVectorPart(end, args, elements, elements); ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); } else if (block.width() > 1) { isLive = true; auto v_begin = INPUT_VARIABLE(1); auto v_end = INPUT_VARIABLE(2); elements = v_begin->lengthOf(); REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSlice: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); REQUIRE_TRUE((v_begin->rankOf() == 1 ) && (v_begin->rankOf() == v_end->rankOf()), 0, "StridedSlice: Rank of begin and ends should be 1, but %i given instead", (int)v_end->rankOf()); for (int e = 0; e < v_begin->lengthOf(); e++) begin.emplace_back(v_begin->e(e)); for (int e = 0; e < v_end->lengthOf(); e++) end.emplace_back(v_end->e(e)); if (block.width() > 3) { auto v_stride = INPUT_VARIABLE(3); REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSlice: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, "StridedSlice: Rank of begin and ends should be %i, but %i given instead", (int) v_begin->rankOf(), v_stride->rankOf()); for (int e = 0; e < v_stride->lengthOf(); e++) strides.emplace_back(v_stride->e(e)); } else { for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); } } else { REQUIRE_TRUE(false, 0, "StridedSlice: Can't find begin/end/stride information neither in IArguments or in input arrays"); } // validation of begin and start std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); if (shrink_axis_mask == 0) for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { if(moveAxes[dim]) continue; if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); } if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); } ++b; ++e; } std::vector indices; auto input_shape = x->getShapeAsVector(); std::vector final_shape; bool is_identity; bool is_simple_slice; bool is_dim0; // FIXME: remove this method once we get 1D vectors supported //vectorize(input_shape); REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSlice: shape calculation failed"); // if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == 2 && input_shape[0] == 1)) { //(indices.size() == 6) && (indices[2] - indices[0] == 1)) { // z->assign(x->e(indices[0])); // } // else { if (indices.size()) { Nd4jLong* subArrShapeInfo = nullptr; ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong); Nd4jLong offset; shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), subArrShapeInfo, offset, true, true); auto subArrShapeInfoPack = ConstantShapeHelper::getInstance().bufferForShapeInfo(subArrShapeInfo); NDArray::prepareSpecialUse({z}, {x}); NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign, x->bufferWithOffset(offset), subArrShapeInfoPack.primary(), x->specialBufferWithOffset(offset), subArrShapeInfoPack.special(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr, true); NDArray::registerSpecialUse({z}, {x}); RELEASE(subArrShapeInfo, block.getWorkspace()); } else if (!z->isEmpty()) { z->assign(x->e(0)); } return Status::OK(); } DECLARE_SYN(stridedslice, strided_slice); DECLARE_SHAPE_FN(strided_slice) { auto inShape = inputShape->at(0); int begin_mask = INT_ARG(0); int ellipsis_mask = INT_ARG(1); int end_mask = INT_ARG(2); int new_axis_mask = INT_ARG(3); int shrink_axis_mask = INT_ARG(4); int x_rank = shape::rank(inShape); int dim_values = block.getIArguments()->size() - 5; int delta = dim_values % 3; int elements = dim_values / 3; std::vector begin; std::vector end; std::vector strides; // if that's live - shape will be resolved in runtime if (block.width() > 1) { begin = INPUT_VARIABLE(1)->template asVectorT(); end = INPUT_VARIABLE(2)->template asVectorT(); strides = INPUT_VARIABLE(3)->template asVectorT(); } else if (dim_values > 0) { int delta2 = dim_values / x_rank; std::vector args; for (int e = 5; e < block.getIArguments()->size(); e++) args.emplace_back(INT_ARG(e)); // FIXME: probably template required here ShapeUtils::copyVectorPart(begin, args, elements, 0); ShapeUtils::copyVectorPart(end, args, elements, elements); ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); } REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, "Strided_Slice: empty arguments"); // validation of begin and start std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); //if (0 == shrink_axis_mask) if (false) for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) { if(moveAxes[dim]) continue; if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; REQUIRE_TRUE(first <= inShape[dim + 1], 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); } if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; REQUIRE_TRUE(last <= inShape[dim + 1], 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); } ++b; ++e; } std::vector input_shape; //(shape::rank(inShape)); auto inputLen = shape::length(inShape); std::vector shape; auto rank = shape::rank(inShape); auto shortShape = shape::shapeOf(inShape); for (auto e = 0; e < rank; e++) input_shape.emplace_back(shortShape[e]); bool is_identity; bool is_simple_slice; bool is_dim0; std::vector indices; bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); if (indices.size()) { auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); // if (inputLen > 1) { // newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(inShape), 'c', // shape); // } else { // newShape = ConstantShapeHelper::getInstance().scalarShapeInfo(ArrayOptions::dataType(inShape)); // } return SHAPELIST(newShape); } return SHAPELIST(ConstantShapeHelper::getInstance().emptyShapeInfo(ArrayOptions::dataType(inShape))); } CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) { auto x = INPUT_VARIABLE(0); auto epsNext = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); int begin_mask = INT_ARG(0); int ellipsis_mask = INT_ARG(1); int end_mask = INT_ARG(2); int new_axis_mask = INT_ARG(3); int shrink_axis_mask = INT_ARG(4); int dim_values = 0; //block.getIArguments()->size() - 5; int delta = 0; //dim_values % 3; int elements = 0; //dim_values / 3; std::vector begin; std::vector end; std::vector strides; bool isLive = false; std::vector args; // statically evaluated if (block.getIArguments()->size() > 5) { dim_values = block.getIArguments()->size() - 5; delta = dim_values % 3; elements = dim_values / 3; for (int e = 5; e < block.getIArguments()->size(); e++) args.emplace_back(INT_ARG(e)); REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); ShapeUtils::copyVectorPart(begin, args, elements, 0); ShapeUtils::copyVectorPart(end, args, elements, elements); ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); } else if (block.width() >= 3) { isLive = true; auto v_begin = INPUT_VARIABLE(2); auto v_end = INPUT_VARIABLE(3); elements = v_begin->lengthOf(); REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSliceBP: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); for (int e = 0; e < v_begin->lengthOf(); e++) begin.emplace_back(v_begin->e(e)); for (int e = 0; e < v_end->lengthOf(); e++) end.emplace_back(v_end->e(e)); if (block.width() >= 4) { auto v_stride = INPUT_VARIABLE(4); REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSliceBP: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); for (int e = 0; e < v_stride->lengthOf(); e++) strides.emplace_back(v_stride->e(e)); } else { for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); } } else { REQUIRE_TRUE(false, 0, "StridedSliceBP: Can't find begin/end/stride information neither in IArguments or in input arrays"); } // validation of begin and start std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { if(moveAxes[dim]) continue; if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); } if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); } ++b; ++e; } auto input_shape = x->getShapeAsVector(); std::vector indices; std::vector final_shape; bool is_identity; bool is_simple_slice; bool is_dim0; // FIXME: remove this method once we get 1D vectors supported vectorize(input_shape); REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed"); //REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op."); //Zero output array, so unused elements have 0 gradient output->nullify(); // // the first case: only for scalar gradient step if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) { output->p(indices[0], *epsNext); } else { // else for other cases auto sub = (*output)(indices, true, true); sub.assign(epsNext); } return Status::OK(); } DECLARE_SHAPE_FN(strided_slice_bp) { auto inShape = inputShape->at(0); Nd4jLong *newShape; COPY_SHAPE(inShape, newShape); return SHAPELIST(CONSTANT(newShape)); } DECLARE_TYPES(strided_slice) { getOpDescriptor() ->setAllowedInputTypes(sd::DataType::ANY) ->setSameMode(true); } DECLARE_TYPES(strided_slice_bp) { getOpDescriptor() ->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS}); } } } #endif