| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Created by raver119 on 07.10.2017.
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
#include <ops/declarable/OpRegistrator.h>
#include <mutex>
#include <sstream>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace nd4j { | 
					
						
							|  |  |  |     namespace ops { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ///////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         template <typename OpName> | 
					
						
							|  |  |  |         __registrator<OpName>::__registrator() { | 
					
						
							|  |  |  |             auto ptr = new OpName(); | 
					
						
							|  |  |  |             OpRegistrator::getInstance()->registerOperation(ptr); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         template <typename OpName> | 
					
						
							|  |  |  |         __registratorSynonym<OpName>::__registratorSynonym(const char *name, const char *oname) { | 
					
						
							|  |  |  |             auto ptr = reinterpret_cast<OpName *>(OpRegistrator::getInstance()->getOperation(oname)); | 
					
						
							|  |  |  |             if (ptr == nullptr) { | 
					
						
							|  |  |  |                 std::string newName(name); | 
					
						
							|  |  |  |                 std::string oldName(oname); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 OpRegistrator::getInstance()->updateMSVC(nd4j::ops::HashHelper::getInstance()->getLongHash(newName), oldName); | 
					
						
							|  |  |  |                 return; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             OpRegistrator::getInstance()->registerOperation(name, ptr); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ///////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         OpRegistrator* OpRegistrator::getInstance() { | 
					
						
							|  |  |  |             if (!_INSTANCE) | 
					
						
							|  |  |  |                 _INSTANCE = new nd4j::ops::OpRegistrator(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return _INSTANCE; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { | 
					
						
							|  |  |  |             std::pair<Nd4jLong, std::string> pair(newHash, oldName); | 
					
						
							|  |  |  |             _msvc.insert(pair); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         template <typename T> | 
					
						
							|  |  |  |         std::string OpRegistrator::local_to_string(T value) { | 
					
						
							|  |  |  |             //create an output string stream
 | 
					
						
							|  |  |  |             std::ostringstream os ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             //throw the value into the string stream
 | 
					
						
							|  |  |  |             os << value ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             //convert the string stream into a string and return
 | 
					
						
							|  |  |  |             return os.str() ; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         template <> | 
					
						
							|  |  |  |         std::string OpRegistrator::local_to_string(int value) { | 
					
						
							|  |  |  |             //create an output string stream
 | 
					
						
							|  |  |  |             std::ostringstream os ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             //throw the value into the string stream
 | 
					
						
							|  |  |  |             os << value ; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             //convert the string stream into a string and return
 | 
					
						
							|  |  |  |             return os.str() ; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         void OpRegistrator::sigIntHandler(int sig) { | 
					
						
							|  |  |  | #ifndef _RELEASE
 | 
					
						
							|  |  |  |             delete OpRegistrator::getInstance(); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         void OpRegistrator::exitHandler() { | 
					
						
							|  |  |  | #ifndef _RELEASE
 | 
					
						
							|  |  |  |             delete OpRegistrator::getInstance(); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         void OpRegistrator::sigSegVHandler(int sig) { | 
					
						
							|  |  |  | #ifndef _RELEASE
 | 
					
						
							|  |  |  |             delete OpRegistrator::getInstance(); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         OpRegistrator::~OpRegistrator() { | 
					
						
							|  |  |  | #ifndef _RELEASE
 | 
					
						
							|  |  |  |             _msvc.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (auto x : _uniqueD) | 
					
						
							|  |  |  |                 delete x; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |             for (auto x: _uniqueH) | 
					
						
							|  |  |  |                 delete x; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             _uniqueD.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |             _uniqueH.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             _declarablesD.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             _declarablesLD.clear(); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         const char * OpRegistrator::getAllCustomOperations() { | 
					
						
							|  |  |  |             _locker.lock(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (!isInit) { | 
					
						
							| 
									
										
										
										
											2020-02-24 06:51:01 +02:00
										 |  |  |                 for (MAP_IMPL<std::string, nd4j::ops::DeclarableOp*>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                     std::string op = it->first + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->allowsInplace())  + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs())  + ":" | 
					
						
							|  |  |  |                                      + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs())  + ":" | 
					
						
							|  |  |  |                                      + ";" ; | 
					
						
							|  |  |  |                     _opsList += op; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 isInit = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             _locker.unlock(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return _opsList.c_str(); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) { | 
					
						
							|  |  |  |             std::string str(name); | 
					
						
							|  |  |  |             std::pair<std::string, nd4j::ops::DeclarableOp*> pair(str, op); | 
					
						
							|  |  |  |             _declarablesD.insert(pair); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto hash = nd4j::ops::HashHelper::getInstance()->getLongHash(str); | 
					
						
							|  |  |  |             std::pair<Nd4jLong, nd4j::ops::DeclarableOp*> pair2(hash, op); | 
					
						
							|  |  |  |             _declarablesLD.insert(pair2); | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         /**
 | 
					
						
							|  |  |  |          * This method registers operation | 
					
						
							|  |  |  |          * | 
					
						
							|  |  |  |          * @param op | 
					
						
							|  |  |  |          */ | 
					
						
							|  |  |  |         bool OpRegistrator::registerOperation(nd4j::ops::DeclarableOp *op) { | 
					
						
							|  |  |  |             _uniqueD.emplace_back(op); | 
					
						
							|  |  |  |             return registerOperation(op->getOpName()->c_str(), op); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |         void OpRegistrator::registerHelper(nd4j::ops::platforms::PlatformHelper* op) { | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |             std::pair<Nd4jLong, samediff::Engine> p = {op->hash(), op->engine()}; | 
					
						
							|  |  |  |             if (_helpersLH.count(p) > 0) | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |                 throw std::runtime_error("Tried to double register PlatformHelper"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             _uniqueH.emplace_back(op); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |             nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), op->hash(), (int) op->engine()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             std::pair<std::pair<std::string, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> pair({op->name(), op->engine()}, op); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |             _helpersH.insert(pair); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |             std::pair<std::pair<Nd4jLong, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> pair2(p, op); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |             _helpersLH.insert(pair2); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) { | 
					
						
							|  |  |  |             std::string str(name); | 
					
						
							|  |  |  |             return getOperation(str); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         /**
 | 
					
						
							|  |  |  |          * This method returns registered Op by name | 
					
						
							|  |  |  |          * | 
					
						
							|  |  |  |          * @param name | 
					
						
							|  |  |  |          * @return | 
					
						
							|  |  |  |          */ | 
					
						
							|  |  |  |         nd4j::ops::DeclarableOp *OpRegistrator::getOperation(Nd4jLong hash) { | 
					
						
							|  |  |  |             if (!_declarablesLD.count(hash)) { | 
					
						
							|  |  |  |                 if (!_msvc.count(hash)) { | 
					
						
							|  |  |  |                     nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); | 
					
						
							|  |  |  |                     return nullptr; | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                     _locker.lock(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     auto str = _msvc.at(hash); | 
					
						
							|  |  |  |                     auto op = _declarablesD.at(str); | 
					
						
							|  |  |  |                     auto oHash = op->getOpDescriptor()->getHash(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     std::pair<Nd4jLong, nd4j::ops::DeclarableOp*> pair(oHash, op); | 
					
						
							|  |  |  |                     _declarablesLD.insert(pair); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     _locker.unlock(); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return _declarablesLD.at(hash); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         nd4j::ops::DeclarableOp *OpRegistrator::getOperation(std::string& name) { | 
					
						
							|  |  |  |             if (!_declarablesD.count(name)) { | 
					
						
							|  |  |  |                 nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); | 
					
						
							|  |  |  |                 return nullptr; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return _declarablesD.at(name); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |         nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash, samediff::Engine engine) { | 
					
						
							|  |  |  |             std::pair<Nd4jLong, samediff::Engine> p = {hash, engine}; | 
					
						
							|  |  |  |             if (_helpersLH.count(p) == 0) | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |                 throw std::runtime_error("Requested helper can't be found"); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |             return _helpersLH[p]; | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |         bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { | 
					
						
							|  |  |  |             std::pair<Nd4jLong, samediff::Engine> p = {hash, engine}; | 
					
						
							|  |  |  |             return _helpersLH.count(p) > 0; | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         int OpRegistrator::numberOfOperations() { | 
					
						
							|  |  |  |             return (int) _declarablesLD.size(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         std::vector<Nd4jLong> OpRegistrator::getAllHashes() { | 
					
						
							|  |  |  |             std::vector<Nd4jLong> result; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (auto &v:_declarablesLD) { | 
					
						
							|  |  |  |                 result.emplace_back(v.first); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return result; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         nd4j::ops::OpRegistrator* nd4j::ops::OpRegistrator::_INSTANCE = 0; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-24 06:51:01 +02:00
										 |  |  | namespace std { | 
					
						
							|  |  |  |     size_t hash<std::pair<Nd4jLong, samediff::Engine>>::operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const { | 
					
						
							|  |  |  |         using std::hash; | 
					
						
							|  |  |  |         auto res = std::hash<Nd4jLong>()(k.first); | 
					
						
							|  |  |  |         res ^= std::hash<int>()((int) k.second)  + 0x9e3779b9 + (res << 6) + (res >> 2); | 
					
						
							|  |  |  |         return res; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     size_t hash<std::pair<std::string, samediff::Engine>>::operator()(const std::pair<std::string, samediff::Engine>& k) const { | 
					
						
							|  |  |  |         using std::hash; | 
					
						
							|  |  |  |         auto res = std::hash<std::string>()(k.first); | 
					
						
							|  |  |  |         res ^= std::hash<int>()((int) k.second)  + 0x9e3779b9 + (res << 6) + (res >> 2); | 
					
						
							|  |  |  |         return res; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 |