* initial commit Signed-off-by: raver119 <raver119@gmail.com> * one file Signed-off-by: raver119 <raver119@gmail.com> * few more includes Signed-off-by: raver119 <raver119@gmail.com> * m? Signed-off-by: raver119 <raver119@gmail.com> * const Signed-off-by: raver119 <raver119@gmail.com> * cudnn linkage in tests Signed-off-by: raver119 <raver119@gmail.com> * culibos Signed-off-by: raver119 <raver119@gmail.com> * static reminder Signed-off-by: raver119 <raver119@gmail.com> * platform engine tag Signed-off-by: raver119 <raver119@gmail.com> * HAVE_CUDNN moved to config.h.in Signed-off-by: raver119 <raver119@gmail.com> * include Signed-off-by: raver119 <raver119@gmail.com> * include Signed-off-by: raver119 <raver119@gmail.com> * skip cudnn handle creation if there's not cudnn Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * target device in context Signed-off-by: raver119 <raver119@gmail.com> * platform engines Signed-off-by: raver119 <raver119@gmail.com> * platform engines Signed-off-by: raver119 <raver119@gmail.com> * allow multiple -h args Signed-off-by: raver119 <raver119@gmail.com> * allow multiple -h args Signed-off-by: raver119 <raver119@gmail.com> * move mkldnn out of CPU block Signed-off-by: raver119 <raver119@gmail.com> * link to mkldnn on cuda Signed-off-by: raver119 <raver119@gmail.com> * less prints Signed-off-by: raver119 <raver119@gmail.com> * minor tweaks Signed-off-by: raver119 <raver119@gmail.com> * next step Signed-off-by: raver119 <raver119@gmail.com> * conv2d NCHW draft Signed-off-by: raver119 <raver119@gmail.com> * conv2d biasAdd Signed-off-by: raver119 <raver119@gmail.com> * test for MKL/CUDNN combined use Signed-off-by: raver119 <raver119@gmail.com> * - provide additional code for conv2d ff based on cudnn api, not tested yet Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on conv2d helper based on using cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - fixing several cuda bugs which 
appeared after cudnn lib had been started to use Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation of conv2d backprop op based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - implementaion of conv3d and conv3d_bp ops based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - bugs fixing in conv3d/conv3d_bp ops (cudnn in use) Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation of depthwiseConv2d (ff/bp) op based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation of batchnorm ff op based on cudnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - disable cudnn batchnorm temporary Signed-off-by: Yurii <iuriish@yahoo.com> * - add minor change in cmake Signed-off-by: Yurii <iuriish@yahoo.com> * engine for depthwise mkldnn Signed-off-by: raver119 <raver119@gmail.com> * couple of includes Signed-off-by: raver119 <raver119@gmail.com> * - provide permutation to cudnn batchnorm ff when format is NHWC Signed-off-by: Yurii <iuriish@yahoo.com> * lgamma fix Signed-off-by: raver119 <raver119@gmail.com> * - eliminate memory leak in two tests Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
264 lines
8.6 KiB
C++
264 lines
8.6 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// Created by raver119 on 07.10.2017.
|
|
//
|
|
|
|
|
|
|
|
#include <ops/declarable/OpRegistrator.h>

#include <mutex>
#include <sstream>
|
|
|
|
namespace nd4j {
|
|
namespace ops {
|
|
|
|
///////////////////////////////
|
|
|
|
template <typename OpName>
|
|
__registrator<OpName>::__registrator() {
|
|
auto ptr = new OpName();
|
|
OpRegistrator::getInstance()->registerOperation(ptr);
|
|
}
|
|
|
|
|
|
template <typename OpName>
|
|
__registratorSynonym<OpName>::__registratorSynonym(const char *name, const char *oname) {
|
|
auto ptr = reinterpret_cast<OpName *>(OpRegistrator::getInstance()->getOperation(oname));
|
|
if (ptr == nullptr) {
|
|
std::string newName(name);
|
|
std::string oldName(oname);
|
|
|
|
OpRegistrator::getInstance()->updateMSVC(nd4j::ops::HashHelper::getInstance()->getLongHash(newName), oldName);
|
|
return;
|
|
}
|
|
OpRegistrator::getInstance()->registerOperation(name, ptr);
|
|
}
|
|
|
|
///////////////////////////////
|
|
|
|
|
|
/**
 * Returns the process-wide registry, allocating it on the first call.
 * NOTE(review): first-time creation is not synchronized; concurrent first
 * calls could race.
 */
OpRegistrator* OpRegistrator::getInstance() {
    if (_INSTANCE == nullptr) {
        _INSTANCE = new nd4j::ops::OpRegistrator();
    }

    return _INSTANCE;
}
|
|
|
|
|
|
/**
 * Records that the op name hashing to `newHash` is an alias of `oldName`.
 * The entry is consulted by getOperation(Nd4jLong) when the hash is not
 * indexed directly.
 */
void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) {
    _msvc.insert(std::make_pair(newHash, oldName));
}
|
|
|
|
template <typename T>
|
|
std::string OpRegistrator::local_to_string(T value) {
|
|
//create an output string stream
|
|
std::ostringstream os ;
|
|
|
|
//throw the value into the string stream
|
|
os << value ;
|
|
|
|
//convert the string stream into a string and return
|
|
return os.str() ;
|
|
}
|
|
|
|
template <>
|
|
std::string OpRegistrator::local_to_string(int value) {
|
|
//create an output string stream
|
|
std::ostringstream os ;
|
|
|
|
//throw the value into the string stream
|
|
os << value ;
|
|
|
|
//convert the string stream into a string and return
|
|
return os.str() ;
|
|
}
|
|
|
|
void OpRegistrator::sigIntHandler(int sig) {
|
|
#ifndef _RELEASE
|
|
delete OpRegistrator::getInstance();
|
|
#endif
|
|
}
|
|
|
|
void OpRegistrator::exitHandler() {
|
|
#ifndef _RELEASE
|
|
delete OpRegistrator::getInstance();
|
|
#endif
|
|
}
|
|
|
|
void OpRegistrator::sigSegVHandler(int sig) {
|
|
#ifndef _RELEASE
|
|
delete OpRegistrator::getInstance();
|
|
#endif
|
|
}
|
|
|
|
/**
 * Destroys the registry.
 *
 * Outside _RELEASE builds it deletes every op recorded in _uniqueD and every
 * platform helper recorded in _uniqueH, then empties the lookup tables. In
 * _RELEASE builds nothing is freed here.
 */
OpRegistrator::~OpRegistrator() {
#ifndef _RELEASE
    _msvc.clear();

    // release owned objects
    for (const auto declarable : _uniqueD)
        delete declarable;
    for (const auto helper : _uniqueH)
        delete helper;

    // drop the (now dangling) pointers from every index
    _uniqueD.clear();
    _uniqueH.clear();
    _declarablesD.clear();
    _declarablesLD.clear();
#endif
}
|
|
|
|
const char * OpRegistrator::getAllCustomOperations() {
|
|
_locker.lock();
|
|
|
|
if (!isInit) {
|
|
for (std::map<std::string, nd4j::ops::DeclarableOp*>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) {
|
|
std::string op = it->first + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->getHash()) + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->allowsInplace()) + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + ":"
|
|
+ local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + ":"
|
|
+ ";" ;
|
|
_opsList += op;
|
|
}
|
|
|
|
isInit = true;
|
|
}
|
|
|
|
_locker.unlock();
|
|
|
|
return _opsList.c_str();
|
|
}
|
|
|
|
|
|
bool OpRegistrator::registerOperation(const char* name, nd4j::ops::DeclarableOp* op) {
|
|
std::string str(name);
|
|
std::pair<std::string, nd4j::ops::DeclarableOp*> pair(str, op);
|
|
_declarablesD.insert(pair);
|
|
|
|
auto hash = nd4j::ops::HashHelper::getInstance()->getLongHash(str);
|
|
std::pair<Nd4jLong, nd4j::ops::DeclarableOp*> pair2(hash, op);
|
|
_declarablesLD.insert(pair2);
|
|
return true;
|
|
}
|
|
|
|
/**
|
|
* This method registers operation
|
|
*
|
|
* @param op
|
|
*/
|
|
bool OpRegistrator::registerOperation(nd4j::ops::DeclarableOp *op) {
|
|
_uniqueD.emplace_back(op);
|
|
return registerOperation(op->getOpName()->c_str(), op);
|
|
}
|
|
|
|
/**
 * Registers a platform-specific implementation (helper) of an op.
 *
 * The helper is indexed by (op name, engine) in _helpersH and by
 * (op hash, engine) in _helpersLH. It is also recorded in _uniqueH, which
 * the destructor deletes from outside _RELEASE builds.
 *
 * @param op helper instance
 * @throws std::runtime_error if a helper with the same (hash, engine) key is already registered
 */
void OpRegistrator::registerHelper(nd4j::ops::platforms::PlatformHelper* op) {
    std::pair<Nd4jLong, samediff::Engine> hashKey(op->hash(), op->engine());
    if (_helpersLH.count(hashKey) > 0)
        throw std::runtime_error("Tried to double register PlatformHelper");

    _uniqueH.emplace_back(op);

    nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), op->hash(), (int) op->engine());

    std::pair<std::string, samediff::Engine> nameKey(op->name(), op->engine());
    _helpersH.insert(std::make_pair(nameKey, op));
    _helpersLH.insert(std::make_pair(hashKey, op));
}
|
|
|
|
/**
 * C-string convenience overload of getOperation(std::string&).
 *
 * @return the op registered under `name`, or nullptr if there is none
 */
nd4j::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) {
    std::string opName(name);   // named lvalue: the string overload takes std::string&
    return getOperation(opName);
}
|
|
|
|
/**
|
|
* This method returns registered Op by name
|
|
*
|
|
* @param name
|
|
* @return
|
|
*/
|
|
nd4j::ops::DeclarableOp *OpRegistrator::getOperation(Nd4jLong hash) {
|
|
if (!_declarablesLD.count(hash)) {
|
|
if (!_msvc.count(hash)) {
|
|
nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash);
|
|
return nullptr;
|
|
} else {
|
|
_locker.lock();
|
|
|
|
auto str = _msvc.at(hash);
|
|
auto op = _declarablesD.at(str);
|
|
auto oHash = op->getOpDescriptor()->getHash();
|
|
|
|
std::pair<Nd4jLong, nd4j::ops::DeclarableOp*> pair(oHash, op);
|
|
_declarablesLD.insert(pair);
|
|
|
|
_locker.unlock();
|
|
}
|
|
}
|
|
|
|
return _declarablesLD.at(hash);
|
|
}
|
|
|
|
/**
 * Returns the op registered under `name`, or nullptr (with a debug message)
 * when no such op exists.
 */
nd4j::ops::DeclarableOp *OpRegistrator::getOperation(std::string& name) {
    auto found = _declarablesD.find(name);
    if (found == _declarablesD.end()) {
        nd4j_debug("Unknown operation requested: [%s]\n", name.c_str());
        return nullptr;
    }

    return found->second;
}
|
|
|
|
/**
 * Returns the helper registered for the op hash on the given engine.
 *
 * @throws std::runtime_error if no such helper is registered (see hasHelper())
 */
nd4j::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash, samediff::Engine engine) {
    std::pair<Nd4jLong, samediff::Engine> key(hash, engine);

    auto found = _helpersLH.find(key);
    if (found == _helpersLH.end())
        throw std::runtime_error("Requested helper can't be found");

    return found->second;
}
|
|
|
|
/**
 * Checks whether a platform helper exists for the op hash on the given engine.
 */
bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) {
    const std::pair<Nd4jLong, samediff::Engine> key(hash, engine);
    return _helpersLH.find(key) != _helpersLH.end();
}
|
|
|
|
int OpRegistrator::numberOfOperations() {
|
|
return (int) _declarablesLD.size();
|
|
}
|
|
|
|
std::vector<Nd4jLong> OpRegistrator::getAllHashes() {
|
|
std::vector<Nd4jLong> result;
|
|
|
|
for (auto &v:_declarablesLD) {
|
|
result.emplace_back(v.first);
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
// Singleton storage: null until the first getInstance() call; reset to null by the teardown handlers.
nd4j::ops::OpRegistrator* nd4j::ops::OpRegistrator::_INSTANCE = nullptr;
|
|
}
|
|
}
|
|
|