/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_tensormmul)

#include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h>
#include <MmulHelper.h>

namespace nd4j {
    namespace ops {
        /**
         * tensormmul (alias: tensordot): computes the tensor dot product of two inputs
         * by delegating to MmulHelper::tensorDot.
         *
         * Inputs:
         *   0 - array A
         *   1 - array B (its data type must equal the data type of A)
         * Output:
         *   0 - array C, receives the result
         *
         * Integer arguments layout (n0 = number of A axes, n1 = number of B axes):
         *   [0]                     - n0
         *   [1 .. n0]               - axes of A
         *   [n0 + 1]                - n1
         *   [n0 + 2 .. n0 + n1 + 1] - axes of B
         */
        CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
            auto a = INPUT_VARIABLE(0);
            auto b = INPUT_VARIABLE(1);

            auto c = OUTPUT_VARIABLE(0);

            // only the two inputs are compared here, so the message names only A and B
            REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A and B data types must be the same");

            // building axes: the count for B is stored right after the last axis of A
            int axe0_size = INT_ARG(0);
            int axe1_size = INT_ARG(axe0_size + 1);
            std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
            for (int e = 0; e < axe0_size; e++)
                axes_0[e] = static_cast<int>(INT_ARG(e + 1));

            for (int e = 0; e < axe1_size; e++)
                axes_1[e] = static_cast<int>(INT_ARG(e + axe0_size + 2));

            // pass the int counts: "%i" expects int, while vector::size() yields size_t
            nd4j_verbose("axe0: %i; axe1: %i;\n", axe0_size, axe1_size);

            MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
            return Status::OK();
        }
        DECLARE_SYN(tensordot, tensormmul);


        /**
         * Shape function of tensormmul.
         *
         * Reads the axes of A and B from the integer arguments
         * ([n0, axes of A..., n1, axes of B...]), asks ShapeUtils::evalShapeForTensorDot
         * for the resulting shape and returns a single 'c'-ordered shape info that
         * carries the data type of input A.
         */
        DECLARE_SHAPE_FN(tensormmul) {

            auto shapeOfA = inputShape->at(0);
            auto shapeOfB = inputShape->at(1);

            // the output inherits this type, so look it up once
            const auto typeOfA = ArrayOptions::dataType(shapeOfA);
            REQUIRE_TRUE(typeOfA == ArrayOptions::dataType(shapeOfB), 0, "tensormmul: A and B data types must be the same");

            // axes counts: the one for B sits right behind the last axis of A
            const int numAxesA = INT_ARG(0);
            const int numAxesB = INT_ARG(numAxesA + 1);

            std::vector<int> axesA(numAxesA);
            std::vector<int> axesB(numAxesB);

            for (int i = 0; i < numAxesA; ++i)
                axesA[i] = static_cast<int>(INT_ARG(i + 1));

            for (int i = 0; i < numAxesB; ++i)
                axesB[i] = static_cast<int>(INT_ARG(i + numAxesA + 2));

            // the helper also fills permutations / intermediate shapes; only its return value is used here
            std::vector<int> permuteA, permuteB;
            std::vector<Nd4jLong> reshapedA, reshapedB;
            auto resultShape = nd4j::ShapeUtils::evalShapeForTensorDot(shapeOfA, shapeOfB, axesA, axesB, permuteA, permuteB, reshapedA, reshapedB);

            return SHAPELIST(
                    ConstantShapeHelper::getInstance()->createShapeInfo(
                            ShapeDescriptor(typeOfA, 'c', resultShape)));
        }

        /**
         * Data type constraints of tensormmul: both inputs and the output
         * are restricted to FLOAT32, DOUBLE and HALF.
         */
        DECLARE_TYPES(tensormmul) {
            auto descriptor = getOpDescriptor();

            // inputs A and B
            descriptor->setAllowedInputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF});
            descriptor->setAllowedInputTypes(1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF});

            // output C
            descriptor->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF});
        }
    }
}

#endif