/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #if NOT_EXCLUDED(OP_split_v) #include namespace sd { namespace ops { CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) { auto input = INPUT_VARIABLE(0); auto sizes = INPUT_VARIABLE(1); int axis = 0; if (block.getIArguments()->size() > 0) { axis = INT_ARG(0); } else if (block.width() > 2){ auto _a = INPUT_VARIABLE(2); axis = _a->e(0); } if (axis < 0) axis += input->rankOf(); std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); int pos = 0; std::vector indices(2 * input->rankOf()); for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { int c_size = sizes->e(e); for (int d = 0; d < input->rankOf(); d++) { if (d == axis) indices[2*d + 1] = (indices[2*d] = pos) + c_size; else indices[2*d] = indices[2*d + 1] = 0; } auto output = OUTPUT_VARIABLE(e); REQUIRE_TRUE(output->dataType() == input->dataType(), 0, "SplitV: all outputs must have same data type as input"); auto sub = (*input)(indices); output->assign(sub); pos += c_size; } //delete tads; return Status::OK(); } DECLARE_TYPES(split_v) { getOpDescriptor() ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_INTS}) ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } DECLARE_SHAPE_FN(split_v) { auto input = inputShape->at(0); //auto sizes = inputShape->at(1); auto shapeList = SHAPELIST(); int rank = shape::rank(input); // 0 is just default axis int axis = 0; if (block.getIArguments()->size() > 0) axis = INT_ARG(0); else if (block.width() > 2) { auto _a = INPUT_VARIABLE(2); axis = _a->e(0); } if (axis < 0) axis += shape::rank(input); // this op assumes we have sizes defined auto sizes = INPUT_VARIABLE(1); auto length = sizes->lengthOf(); int pos = 0; for (Nd4jLong e = 0; e < length; e++) { int c_size = sizes->e(e); std::vector shape(rank); for (int d = 0; d < rank; d++) { if (d != axis) shape[d] = shape::sizeAt(input, d); else shape[d] = c_size; } auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(ArrayOptions::dataType(input), shape::order(input), shape); shapeList->push_back(newShape); } return shapeList; } } } #endif