/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 30.01.18.
//

#include <graph/execution/LogicMerge.h>
#include <graph/Status.h>

namespace sd {
    namespace graph {
        /**
         * Executes a Merge node by exposing one of its inputs as the node's own output variable (index 0).
         *
         * The output variable is handed the selected input's NDArray (as returned by getNDArray())
         * and is marked read-only.
         *
         * Two modes are handled:
         *  - Loop mode: the node behind the second input is a LOGIC op carrying the NextIteration opNum.
         *    That node is told which layer/node to rewind to, the merge node's frame (if it has one) is
         *    marked active, and the second input is propagated once its variable exists and its node
         *    was executed; otherwise the first input is propagated.
         *  - Plain mode: the first input that has a variable holding an NDArray and whose producing
         *    node is active is propagated. If no input qualifies, nothing is changed.
         *
         * @param graph graph owning the node; supplies the variable space, the flow path and node lookup
         * @param node  merge node being processed; inputs 0 and 1 are read unconditionally
         * @return Status::OK() in all cases
         */
        Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) {
            // opNum that identifies the NextIteration logic op
            constexpr long kNextIterationOpNum = 80L;

            auto variableSpace = graph->getVariableSpace();
            auto flowPath = variableSpace->flowPath();

            // Points this node's output variable (index 0) at the array held by `source`.
            // When `releasePrevious` is set, the array previously held by the output variable is deleted.
            auto publish = [&](Variable *source, bool releasePrevious) {
                Variable *target = nullptr;
                if (variableSpace->hasVariable(node->id(), 0)) {
                    target = variableSpace->getVariable(node->id(), 0);
                } else {
                    // NOTE(review): this freshly allocated Variable is not registered with the variable
                    // space anywhere in this function - confirm whether that is intended or a leak.
                    target = new Variable(nullptr, node->getName()->c_str(), node->id(), 0);
                }

                auto array = source->getNDArray();

                // If the output already holds this very array, deleting it and storing the same
                // pointer again would leave the variable pointing at freed memory, so keep it.
                // NOTE(review): a previously stored array may itself belong to another variable
                // (the output is marked read-only below) - confirm that deleting it here is safe.
                if (releasePrevious && target->hasNDArray() && target->getNDArray() != array)
                    delete target->getNDArray();

                target->setNDArray(array);
                target->markReadOnly(true);
            };

            // merge MUST have 2 inputs
            auto inputAddr0 = node->input()->at(0);
            auto inputAddr1 = node->input()->at(1);

            // the merge is part of a loop when its second input comes from a NextIteration node
            bool isWhile = false;
            if (graph->hasNode(inputAddr1.first)) {
                auto secondNode = graph->nodeById(inputAddr1.first);

                if (secondNode->opType() == OpType_LOGIC && secondNode->opNum() == kNextIterationOpNum) {
                    isWhile = true;

                    // notifying NextIteration node for rewind index
                    secondNode->setRewindLayer(node->getLayer());
                    secondNode->setRewindNode(node->id());
                }
            }

            // FIXME: we don't need this check. Just last input should survive, IF it exists
            if (isWhile) {
                if (node->getFrameId() >= 0)
                    flowPath->markFrameActive(node->getFrameId(), true);

                bool hasVar = variableSpace->hasVariable(inputAddr1);
                if (hasVar && flowPath->wasExecuted(inputAddr1.first)) {
                    nd4j_debug("Node_%i: propagating second input\n", node->id());

                    // in loop mode the previously published array is intentionally not deleted
                    publish(variableSpace->getVariable(inputAddr1), false);

                    // reset the executed mark of the second input's node
                    flowPath->markExecuted(inputAddr1.first, false);
                } else {
                    nd4j_debug("Node_%i: propagating first input\n", node->id());
                    publish(variableSpace->getVariable(inputAddr0), false);
                }
            } else {
                // basically, first non-null variable is our target
                // (the address is taken by value, exactly as the original index loop did)
                for (auto inputAddr : *node->input()) {
                    if (!variableSpace->hasVariable(inputAddr))
                        continue;

                    auto var = variableSpace->getVariable(inputAddr);
                    if (!var->hasNDArray() || !flowPath->isNodeActive(inputAddr.first))
                        continue;

                    publish(var, true);
                    break;
                }
            }

            return Status::OK();
        }
    }
}