/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author sgazeos@gmail.com
//

#include <memory>

#include <op_boilerplate.h>
#include <NDArray.h>
#include <helpers/ShapeUtils.h>


namespace nd4j {
    namespace ops {
        namespace helpers {

            /**
             * Backpropagation for element-wise maximum, z = max(x, y).
             *
             * The incoming gradient epsNext is routed to whichever input "won" the
             * comparison. Ties (x == y) pass the gradient to BOTH inputs, because the
             * masks below use >= for x and <= for y.
             *
             * Three input layouts are handled:
             *  - x and y have the same shape: pure element-wise masking.
             *  - y is a scalar: gradX is masked element-wise; gradY is the sum of
             *    epsNext over all positions where x <= y.
             *  - otherwise (broadcast): x and y are tiled up to epsNext's shape,
             *    masked, then summed back along their broadcast axes.
             *
             * @param x       first forward input
             * @param y       second forward input
             * @param epsNext gradient w.r.t. the forward output; expected to have the
             *                forward output shape (NOTE(review): confirm with caller)
             * @param gradX   output: gradient w.r.t. x
             * @param gradY   output: gradient w.r.t. y
             */
            template <typename T>
            void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {

                // x receives the gradient wherever it is the (possibly tied) maximum
                auto lambdaX = LAMBDA_TTT(_e, _x, _y) {
                    return _x >= _y ? _e : static_cast<T>(0.);
                };

                // y receives the gradient wherever it is the (possibly tied) maximum
                auto lambdaY = LAMBDA_TTT(_e, _x, _y) {
                    return _x <= _y ? _e : static_cast<T>(0.);
                };

                if (x->isSameShape(y)) {
                    // pairwise case: gradients map 1:1 onto the inputs
                    epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX);
                    epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);

                } else if (y->isScalar()) {
                    // scalar case: the single y value was compared against every element of x
                    T s = y->e<T>(0);

                    auto lambdaXS = LAMBDA_TT(_e, _x, s) {
                        return _x >= s ? _e : static_cast<T>(0.);
                    };
                    auto lambdaYS = LAMBDA_TT(_e, _x, s) {
                        return _x <= s ? _e : static_cast<T>(0.);
                    };

                    // gradY accumulates epsNext over every position where y won.
                    // The comparison must be on element VALUES, not on the NDArray pointers.
                    std::unique_ptr<NDArray> maskedEps(epsNext->dup());
                    epsNext->applyPairwiseLambda(x, lambdaYS, maskedEps.get());
                    auto sumY = maskedEps->reduceNumber(reduce::Sum);
                    gradY->assign(sumY);

                    epsNext->applyPairwiseLambda(x, lambdaXS, gradX);
                } else {
                    // broadcast case: boost x and y to the shape of epsNext (the forward
                    // output shape), mask there, then reduce back to each input's shape.
                    std::unique_ptr<NDArray> preX(x->dup());
                    std::unique_ptr<NDArray> preY(y->dup());

                    auto targetShape = epsNext->getShapeAsVector();
                    preX->tileToShape(targetShape);
                    preY->tileToShape(targetShape);

                    // The X mask goes to its own buffer. Writing it into preX would clobber
                    // the tiled x values that the Y mask below still has to read.
                    std::unique_ptr<NDArray> fullGradX(epsNext->dup());
                    epsNext->applyTriplewiseLambda(preX.get(), preY.get(), lambdaX, fullGradX.get());
                    // preX is still intact, so the Y mask may overwrite preY element-wise.
                    epsNext->applyTriplewiseLambda(preX.get(), preY.get(), lambdaY, preY.get());

                    // Sum a full-shape gradient back down to the shape of its input.
                    auto reduceToInput = [epsNext](NDArray* input, NDArray* fullGrad, NDArray* grad) {
                        auto axes = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), epsNext->shapeInfo());
                        if (axes.empty()) {
                            grad->assign(fullGrad);
                            return;
                        }
                        std::unique_ptr<NDArray> sum(fullGrad->reduceAlongDimension(reduce::Sum, axes));
                        grad->assign(sum.get());
                    };

                    reduceToInput(x, fullGradX.get(), gradX);
                    reduceToInput(y, preY.get(), gradY);
                }
            }

            /**
             * Entry point for the maximum backprop helper.
             *
             * Brackets the computation with NDArray::prepareSpecialUse /
             * NDArray::registerSpecialUse. The outputs (gradX, gradY) go in the first
             * list and the inputs (x, y, epsNext) in the second. Work is dispatched to
             * maximumBPFunctor_<T>, with T selected from x's data type among
             * NUMERIC_TYPES. `context` is not used by this function.
             */
            void maximumBPFunctor(nd4j::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) {
                // outputs first, read-only inputs second
                NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext});

                // the element type of x decides which template instantiation runs
                const auto dispatchType = x->dataType();
                BUILD_SINGLE_SELECTOR(dispatchType, maximumBPFunctor_,
                                      (x, y, epsNext, gradX, gradY),
                                      NUMERIC_TYPES);

                NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext});
            }

        }
    }
}