/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 07.10.2017. // #include #include namespace nd4j { namespace ops { /////////////////////////////// template __registrator::__registrator() { auto ptr = new OpName(); OpRegistrator::getInstance()->registerOperation(ptr); } template __registratorSynonym::__registratorSynonym(const char *name, const char *oname) { auto ptr = reinterpret_cast(OpRegistrator::getInstance()->getOperation(oname)); if (ptr == nullptr) { std::string newName(name); std::string oldName(oname); OpRegistrator::getInstance()->updateMSVC(nd4j::ops::HashHelper::getInstance()->getLongHash(newName), oldName); return; } OpRegistrator::getInstance()->registerOperation(name, ptr); } /////////////////////////////// OpRegistrator* OpRegistrator::getInstance() { if (!_INSTANCE) _INSTANCE = new nd4j::ops::OpRegistrator(); return _INSTANCE; } void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { std::pair pair(newHash, oldName); _msvc.insert(pair); } template std::string OpRegistrator::local_to_string(T value) { //create an output string stream std::ostringstream os ; //throw the value into the string stream os << value ; //convert the string stream into a string and return return os.str() ; } template <> std::string OpRegistrator::local_to_string(int value) { //create an output string stream std::ostringstream os ; //throw the value into the string stream os << value ; //convert the string stream into a string and return return os.str() ; } void OpRegistrator::sigIntHandler(int sig) { #ifndef _RELEASE delete OpRegistrator::getInstance(); #endif } void OpRegistrator::exitHandler() { #ifndef _RELEASE delete OpRegistrator::getInstance(); #endif } void OpRegistrator::sigSegVHandler(int sig) { #ifndef _RELEASE delete OpRegistrator::getInstance(); #endif } OpRegistrator::~OpRegistrator() { #ifndef _RELEASE _msvc.clear(); for (auto x : _uniqueD) delete x; _uniqueD.clear(); _declarablesD.clear(); _declarablesLD.clear(); #endif } const char * OpRegistrator::getAllCustomOperations() { _locker.lock(); if (!isInit) { for (std::map::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { std::string op = it->first + ":" + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + ":" + local_to_string(it->second->getOpDescriptor()->allowsInplace()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + ":" + ";" ; _opsList += op; } isInit = true; } _locker.unlock(); return _opsList.c_str(); } bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) { std::string str(name); std::pair pair(str, op); _declarablesD.insert(pair); auto hash = nd4j::ops::HashHelper::getInstance()->getLongHash(str); std::pair pair2(hash, op); _declarablesLD.insert(pair2); return true; } /** * This method registers operation * * @param op */ bool OpRegistrator::registerOperation(nd4j::ops::DeclarableOp *op) { _uniqueD.emplace_back(op); return registerOperation(op->getOpName()->c_str(), op); } nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) { std::string str(name); return getOperation(str); } /** * This method returns registered Op by name * * @param name * @return */ nd4j::ops::DeclarableOp *OpRegistrator::getOperation(Nd4jLong hash) { if (!_declarablesLD.count(hash)) { if (!_msvc.count(hash)) { nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); return nullptr; } else { _locker.lock(); auto str = _msvc.at(hash); auto op = _declarablesD.at(str); auto oHash = op->getOpDescriptor()->getHash(); std::pair pair(oHash, op); _declarablesLD.insert(pair); _locker.unlock(); } } return _declarablesLD.at(hash); } nd4j::ops::DeclarableOp *OpRegistrator::getOperation(std::string& name) { if (!_declarablesD.count(name)) { nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); return nullptr; } return _declarablesD.at(name); } int OpRegistrator::numberOfOperations() { return (int) _declarablesLD.size(); } std::vector OpRegistrator::getAllHashes() { std::vector result; for (auto &v:_declarablesLD) { result.emplace_back(v.first); } return result; } nd4j::ops::OpRegistrator* nd4j::ops::OpRegistrator::_INSTANCE = 0; } }