/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include "../PlatformHelper.h" #include namespace sd { namespace ops { namespace platforms { PlatformHelper::PlatformHelper(const char *name, samediff::Engine engine) { // we just store name/hash of target operation _name = std::string(name); _hash = HashHelper::getInstance().getLongHash(_name); _engine = engine; } sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { auto result = getZ(block, inputId); if (result != nullptr && !block.isInplace()) result->nullify(); return result; } sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) { NDArray *z = nullptr; if (ctx.isFastPath()) { if (ctx.fastpath_out().size() <= inputId) { if (ctx.isInplace()) { z = ctx.fastpath_in()[inputId]; } else throw std::runtime_error("fastpath_out: unresolved output array"); } else { z = ctx.fastpath_out()[inputId]; } } else { std::pair pair(ctx.nodeId(), inputId); if (ctx.isInplace()) { z = ctx.variable(inputId)->getNDArray(); // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now if (!ctx.getVariableSpace()->hasVariable(pair)) { auto var = new graph::Variable(); ctx.getVariableSpace()->putVariable(pair, var); } // now we're saving input array as output array auto var = ctx.getVariableSpace()->getVariable(pair); var->markRemovable(false); var->setNDArray(z); } else if (!ctx.isInplace()) { auto var = ctx.variable(pair); if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { z = var->getNDArray(); } else { nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); } } else { nd4j_printf("BOOM!\n", ""); throw std::runtime_error("Boom!"); } } return z; } samediff::Engine PlatformHelper::engine() { return _engine; } std::string PlatformHelper::name() { return _name; } Nd4jLong PlatformHelper::hash() { return _hash; } } } }