58 lines
2.3 KiB
C++
Raw Normal View History

2021-02-01 21:31:45 +09:00
/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 28.10.2017.
//
#include "graph/execution/LogicReturn.h"
#include <helpers/EnumUtils.h>
#include <graph/Status.h>

#include <cstddef>
2019-06-06 15:21:15 +03:00
namespace sd {
2019-06-06 15:21:15 +03:00
namespace graph {
Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) {
auto __variableSpace = graph->getVariableSpace();
for (int e = 0; e < node->input()->size(); e++) {
auto inputAddr = node->input()->at(e);
auto outputAddr = node->output()->at(e);
// FIXME!!
outputAddr.second = e;
if (Environment::getInstance().isDebugAndVerbose())
2019-06-06 15:21:15 +03:00
nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second);
auto varIn = __variableSpace->getVariable(inputAddr);
auto varOut = __variableSpace->getVariable(outputAddr);
nd4j_debug("Returning varType: [%s]\n", EnumUtils::_VariableTypeToString(varIn->variableType()));
// FIXME: this is obviously wrong, we should keep depth track for backprop here
varOut->getNDArray()->assign(varIn->getNDArray());
if (Environment::getInstance().isDebugAndVerbose())
2019-06-06 15:21:15 +03:00
nd4j_debug("In after: [%f]; Out after: [%f]\n", varIn->getNDArray()->meanNumber().e<float>(0), varOut->getNDArray()->meanNumber().e<float>(0));
}
return sd::Status::OK();
2019-06-06 15:21:15 +03:00
}
}
}