/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #if NOT_EXCLUDED(OP_tensormmul) #include #include #include namespace nd4j { namespace ops { CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) { auto a = INPUT_VARIABLE(0); auto b = INPUT_VARIABLE(1); auto c = OUTPUT_VARIABLE(0); // REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same"); // building axes int axe0_size = INT_ARG(0); int axe1_size = INT_ARG(axe0_size+1); std::vector axes_0(axe0_size), axes_1(axe1_size); for (int e = 0; e < axe0_size; e++) axes_0[e] = (int) INT_ARG(e+1); for (int e = 0; e < axe1_size; e++) axes_1[e] = (int) INT_ARG(e + axe0_size + 2); nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); MmulHelper::tensorDot(a, b, c, axes_0, axes_1); return Status::OK(); } DECLARE_SYN(tensordot, tensormmul); DECLARE_SHAPE_FN(tensormmul) { auto aShapeInfo = inputShape->at(0); auto bShapeInfo = inputShape->at(1); REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same"); // building axes int axe0_size = INT_ARG(0); int axe1_size = INT_ARG(axe0_size+1); std::vector axes_0(axe0_size), axes_1(axe1_size); for (int e = 0; e < axe0_size; e++) axes_0[e] = (int) INT_ARG(e+1); for (int e = 0; e < axe1_size; e++) axes_1[e] = (int) INT_ARG(e + axe0_size + 2); // evaluate shapes std::vector permutAt, permutBt; std::vector shapeAt, shapeBt; auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); } DECLARE_TYPES(tensormmul) { getOpDescriptor() ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } } } #endif