177 lines
5.8 KiB
C++
177 lines
5.8 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
|
|
//
|
|
// Created by raver on 3/2/2019.
|
|
//
|
|
|
|
#ifndef DEV_TESTS_DECLARABLEBENCHMARK_H
|
|
#define DEV_TESTS_DECLARABLEBENCHMARK_H
|
|
|
|
#include <NDArray.h>
|
|
#include <Context.h>
|
|
#include <OpBenchmark.h>
|
|
#include <declarable/DeclarableOp.h>
|
|
#include <declarable/OpRegistrator.h>
|
|
#include <PointersManager.h>
|
|
|
|
namespace nd4j {
|
|
/**
 * Benchmark wrapper that times the execution of a single DeclarableOp
 * against a graph Context.
 *
 * Ownership: the DeclarableOp is only referenced (never deleted). The Context
 * installed via setContext() is owned by this benchmark and deleted in the
 * destructor or when it is replaced by another context.
 *
 * NOTE(review): the implicitly-generated copy operations are still available,
 * but copying an instance that holds a context would delete that context twice.
 * Use clone() (which does not carry the context over) instead of copying.
 */
class ND4J_EXPORT DeclarableBenchmark : public OpBenchmark {
protected:
    // Operation under test; non-owning.
    nd4j::ops::DeclarableOp *_op = nullptr;

    // Inputs/arguments used for execution; owned by this benchmark.
    nd4j::graph::Context *_context = nullptr;

public:
    /**
     * @param op   operation to benchmark; must outlive this benchmark.
     * @param name test name reported by the benchmark (empty by default).
     *             The previous default of `0` constructed std::string from a
     *             null pointer, which is undefined behavior.
     */
    DeclarableBenchmark(nd4j::ops::DeclarableOp &op, std::string name = std::string()) : OpBenchmark(), _op(&op) {
        _testName = name;
    }

    /**
     * Installs the execution context and takes ownership of it.
     * A previously installed (different) context is deleted so it does not leak.
     */
    void setContext(nd4j::graph::Context *ctx) {
        if (_context != ctx)
            delete _context;
        _context = ctx;
    }

    std::string axis() override {
        return "N/A";
    }

    /// Memory order character of each fast-path input, joined with '/'; "N/A" otherwise.
    std::string orders() override {
        // Previously this reported strides (copy-paste of strides()); report ordering instead.
        return describeInputs([](NDArray *arr) { return std::string(1, arr->ordering()); });
    }

    /// Strides of each fast-path input, joined with '/'; "N/A" otherwise.
    std::string strides() override {
        return describeInputs([](NDArray *arr) { return ShapeUtils::strideAsString(arr); });
    }

    std::string inplace() override {
        return "N/A";
    }

    /**
     * Runs the op once on the installed context.
     * A PointersManager is created around the call and synchronized afterwards;
     * presumably this waits for any asynchronous (device) work so the measured
     * time covers it — TODO confirm against PointersManager::synchronize.
     */
    void executeOnce() override {
        PointersManager pm(LaunchContext::defaultContext(), "DeclarableBenchmark");
        _op->execute(_context);
        pm.synchronize();
    }

    /**
     * Creates a new benchmark for the same op and test name.
     * The context is NOT copied: the clone starts without one, so a context
     * must be installed with setContext() before it is executed.
     */
    OpBenchmark *clone() override {
        return new DeclarableBenchmark(*_op, _testName);
    }

    /// Shape of each fast-path input, joined with '/'; "N/A" otherwise.
    std::string shape() override {
        return describeInputs([](NDArray *arr) { return ShapeUtils::shapeAsString(arr); });
    }

    /// Data type of each fast-path input, joined with '/'; "N/A" otherwise.
    std::string dataType() override {
        return describeInputs([](NDArray *arr) { return DataTypeUtils::asString(arr->dataType()); });
    }

    /**
     * Formats the op's integer, floating-point and boolean arguments as
     * "iargs=[..],targs=[..],bargs=[..]". Argument lists the context reports
     * as null are omitted. Returns "N/A" when no context is installed.
     */
    std::string extra() override {
        if (_context == nullptr)
            return "N/A";

        std::string result;
        bool any = false;
        appendArgs(result, "iargs", _context->getIArguments(), any);
        appendArgs(result, "targs", _context->getTArguments(), any);
        appendArgs(result, "bargs", _context->getBArguments(), any);
        return result;
    }

    ~DeclarableBenchmark() {
        // The benchmark owns the context; deleting a null pointer is a no-op.
        delete _context;
    }

private:
    // Applies `describe` to every fast-path input and joins the results with '/'.
    // Returns "N/A" when there is no context or it is not in fast-path mode.
    template <typename Describe>
    std::string describeInputs(Describe describe) {
        if (_context == nullptr || !_context->isFastPath())
            return "N/A";

        std::string result;
        const int count = static_cast<int>(_context->fastpath_in().size());
        for (int i = 0; i < count; i++) {
            if (i > 0)
                result += "/";
            result += describe(_context->getNDArray(i));
        }
        return result;
    }

    // Appends "<label>=[a,b,...]" to `out`, prefixed with ',' when something was
    // already appended. Null argument lists are skipped entirely.
    template <typename T>
    static void appendArgs(std::string &out, const char *label, const std::vector<T> *args, bool &any) {
        if (args == nullptr)
            return;

        if (any)
            out += ",";
        out += label;
        out += "=[";
        for (size_t i = 0; i < args->size(); i++) {
            if (i > 0)
                out += ",";
            out += std::to_string(args->at(i));
        }
        out += "]";
        any = true;
    }
};
|
|
}
|
|
|
|
#endif //DEV_TESTS_DECLARABLEBENCHMARK_H
|