159 lines
5.8 KiB
C++
159 lines
5.8 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author raver119@gmail.com
|
||
|
//
|
||
|
|
||
|
#include <ops/declarable/OpDescriptor.h>
|
||
|
#include <ops/declarable/DeclarableListOp.h>
|
||
|
#include <graph/Context.h>
|
||
|
#include <graph/Variable.h>
|
||
|
#include <graph/VariableSpace.h>
|
||
|
|
||
|
namespace nd4j {
|
||
|
namespace ops {
|
||
|
/**
 * Builds a list op on top of the generic DeclarableOp.
 *
 * All arguments are forwarded to the DeclarableOp constructor unchanged, with
 * the fourth base-constructor argument fixed to false.
 *
 * @param numInputs  forwarded to DeclarableOp
 * @param numOutputs forwarded to DeclarableOp
 * @param opName     forwarded to DeclarableOp
 * @param tArgs      forwarded to DeclarableOp
 * @param iArgs      forwarded to DeclarableOp
 */
DeclarableListOp::DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs)
    : DeclarableOp(numInputs, numOutputs, opName, false, tArgs, iArgs) {
    // List ops operate on sets of arrays (NDArrayList), so the descriptor's
    // input type is switched to the numeric-set kind.
    getOpDescriptor()->setInputType(InputType_NUMERIC_SET);
}
|
||
|
/*
|
||
|
template <typename T>
|
||
|
void DeclarableListOp::execute(Block& block) {
|
||
|
//
|
||
|
}
|
||
|
*/
|
||
|
/**
|
||
|
* This method just outputs scalar buffer
|
||
|
*
|
||
|
* @tparam T
|
||
|
* @param inputShape
|
||
|
* @param block
|
||
|
* @return
|
||
|
*/
|
||
|
ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, nd4j::graph::Context& block) {
|
||
|
// TODO: ensure this method isn't ever called
|
||
|
|
||
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', {1, 1});
|
||
|
return SHAPELIST(newShape);
|
||
|
}
|
||
|
|
||
|
nd4j::NDArray* nd4j::ops::DeclarableListOp::getZ(Context& block, int inputId) {
|
||
|
//nd4j_printf("wow\n","");
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
/**
 * Stores the given array in the block's variable space as output 0 of the
 * block's own node.
 *
 * @param array array to publish
 * @param block context whose node id is used as the key
 */
void DeclarableListOp::setupResult(NDArray* array, Context& block) {
    const auto nodeId = block.getNodeId();
    block.pushNDArrayToVariableSpace(nodeId, 0, array);
}
|
||
|
|
||
|
/**
 * Stores the given NDArrayList in the block's variable space as output 0 of
 * the block's own node.
 *
 * @param arrayList list to publish
 * @param block     context whose node id is used as the key
 */
void DeclarableListOp::setupResultList(NDArrayList* arrayList, Context& block) {
    const auto nodeId = block.getNodeId();
    block.pushNDArrayListToVariableSpace(nodeId, 0, arrayList);
}
|
||
|
|
||
|
/**
 * Convenience overload: copies the initializer lists into vectors and
 * delegates to the vector-based execute().
 *
 * @param list   list to operate on (passed through unchanged)
 * @param inputs input arrays
 * @param tArgs  floating point arguments
 * @param iArgs  integer arguments
 * @return whatever the vector-based overload returns
 */
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
    // The vector-based overload takes non-const references, so named lvalues are needed.
    std::vector<NDArray*> inputArrays(inputs.begin(), inputs.end());
    std::vector<double> realArgs(tArgs.begin(), tArgs.end());
    std::vector<int> intArgs(iArgs.begin(), iArgs.end());

    return execute(list, inputArrays, realArgs, intArgs);
}
|
||
|
|
||
|
/**
 * Executes this list op against the given context.
 *
 * Validates the argument counts, runs validateAndExecute() and records the
 * time spent inside it (in nanoseconds) on the block via setInnerTime().
 *
 * @param block execution context; must not be null
 * @return status returned by validateAndExecute(); REQUIRE_OK guards the
 *         argument validation before the op is run
 * @throws std::invalid_argument if block is nullptr
 */
Nd4jStatus DeclarableListOp::execute(Context* block) {
    if (block == nullptr)
        throw std::invalid_argument("Block is NULL");

    nd4j_debug("Executing list op: [%s]\n", this->getOpName()->c_str());

    // ensure number of IArgs, TArgs match our expectations
    REQUIRE_OK(this->validateArguments(*block));

    // NOTE: prepareOutputs() is intentionally not called for list ops.

    // steady_clock is monotonic, so the measured interval cannot be skewed
    // (or go negative) when the wall clock is adjusted mid-execution.
    const auto timeStart = std::chrono::steady_clock::now();

    const Nd4jStatus status = this->validateAndExecute(*block);

    const auto timeEnd = std::chrono::steady_clock::now();
    const auto elapsedNanos = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
    block->setInnerTime(elapsedNanos);

    return status;
}
|
||
|
|
||
|
/**
 * Runs this list op outside of a graph.
 *
 * A temporary VariableSpace is populated with the optional list (under id -1)
 * and with every input array (under the following negative ids), a Context
 * for node 1 is built on top of it, and validateAndExecute() is invoked.
 * Afterwards every consecutive output (1, 0), (1, 1), ... found in the
 * variable space that holds an NDArray is collected into the result.
 *
 * @param list   list to operate on; may be nullptr. If its id's first
 *               component is 0 it is overwritten with -1.
 * @param inputs input arrays; wrapped in non-removable variables
 * @param tArgs  floating point arguments copied into the context
 * @param iArgs  integer arguments copied into the context
 * @return newly allocated ResultSet carrying the execution status and the
 *         collected arrays (allocated with new — presumably the caller is
 *         expected to delete it; confirm against callers)
 */
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
    VariableSpace varSpace;

    // Temporary variables get negative ids: -1, -2, ... in registration order.
    int cnt = -1;
    std::vector<int> in;
    if (list != nullptr) {
        // should be never used in practice, since in-graph NDArrayList should have id set
        if (list->id().first == 0)
            list->id().first = -1;

        auto listVar = new Variable(nullptr, nullptr, -119, 0);
        listVar->setNDArrayList(list);
        varSpace.putVariable(-1, listVar);
        in.push_back(-1);
        cnt--;
    }

    for (auto v : inputs) {
        auto var = new Variable(v);
        // markRemovable(false): flag the wrapped input array as not removable
        var->markRemovable(false);
        in.push_back(cnt);
        varSpace.putVariable(cnt--, var);
    }

    Context block(1, &varSpace, false);
    block.fillInputs(in);

    for (const auto t : tArgs)
        block.getTArguments()->emplace_back(t);

    for (const auto i : iArgs)
        block.getIArguments()->emplace_back(i);

    Nd4jStatus result = this->validateAndExecute(block);
    auto res = new ResultSet();
    res->setStatus(result);

    // Walk outputs (1, 0), (1, 1), ... and stop at the first missing index.
    for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
        // kept as a named, non-const local: the lookups below may take it by reference
        std::pair<int,int> pair(1, e);
        if (!varSpace.hasVariable(pair))
            break;

        auto var = varSpace.getVariable(pair);
        if (!var->hasNDArray())
            continue;

        auto arr = var->getNDArray();
        if (arr->isAttached()) {
            // attached arrays are handed out as the result of detach()
            auto d = arr->detach();
            res->push_back(d);
        } else {
            // the array itself is handed out, so flag the variable as non-removable
            var->markRemovable(false);
            res->push_back(arr);
        }
    }

    return res;
}
|
||
|
}
|
||
|
}
|