/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author GS // #include #include namespace nd4j { namespace ops { namespace helpers { Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) { if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case // lenght are equals if (x_shape->lengthOf() == y_shape->lengthOf()) { auto greater = (x_shape->e(0) < y_shape->e(0) ? y_shape : x_shape); output->assign(greater); } else { auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape); auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape); output->assign(greater); auto lastG = greater->lengthOf() - 1; auto lastL = lesser->lengthOf() - 1; if (greater->e(lastG) < lesser->e(lastL)) output->p(lastG, lesser->e(lastL)); } } else { //int e = 0, x = 0, y = 0; Nd4jLong xLen = x_shape->lengthOf(); Nd4jLong yLen = y_shape->lengthOf(); Nd4jLong zLen = output->lengthOf(); Nd4jLong borderLen = nd4j::math::nd4j_min(xLen, yLen); for (Nd4jLong e = 0; e < zLen; e++) { Nd4jLong val; if (e < borderLen) { val = nd4j::math::nd4j_max(x_shape->e(e), y_shape->e(e)); } else if (e < xLen) { val = nd4j::math::nd4j_max(x_shape->e(e), y_shape->e(yLen - 1)); } else { val = nd4j::math::nd4j_max(x_shape->e(xLen - 1), y_shape->e(e)); } // if (e) // if (val != output->e(e - 1)) { // nd4j_printf( // "broadcast_dynamic_shape: Input shapes should be compatible, but %lld and %lld were given.\n", // val, output->e(e - 1)); // return Status::CODE(ND4J_STATUS_VALIDATION, "broadcast_dynamic_shape: BDS validation failed!"); // } output->p(e, val); } } return Status::OK(); } } } }