benchmarks fixes (#321)
* bunch of small fixes Signed-off-by: raver119 <raver119@gmail.com> * validation for legacy random op Signed-off-by: raver119 <raver119@gmail.com> * get rid of test Signed-off-by: raver119 <raver119@gmail.com>master
parent
e7a995e959
commit
4cf2afad2b
|
@ -61,10 +61,9 @@ namespace sd {
|
||||||
|
|
||||||
//nd4j_printf("Total Iterations: %i\n", totalIterations);
|
//nd4j_printf("Total Iterations: %i\n", totalIterations);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for (int i = 0; i < totalIterations; i++) {
|
for (int i = 0; i < totalIterations; i++) {
|
||||||
shape::index2coords(i, xRank, xShape, xCoords);
|
if (xRank > 0)
|
||||||
|
shape::index2coords(i, xRank, xShape, xCoords);
|
||||||
|
|
||||||
Parameters params;
|
Parameters params;
|
||||||
for (int j = 0; j < xRank; j++) {
|
for (int j = 0; j < xRank; j++) {
|
||||||
|
|
|
@ -132,7 +132,7 @@ namespace sd {
|
||||||
// this method returns OpDescriptor, describing this Op instance
|
// this method returns OpDescriptor, describing this Op instance
|
||||||
OpDescriptor *getOpDescriptor();
|
OpDescriptor *getOpDescriptor();
|
||||||
|
|
||||||
Nd4jStatus validateDataTypes(Context& block);
|
virtual Nd4jStatus validateDataTypes(Context& block);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s)
|
* This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s)
|
||||||
|
|
|
@ -46,6 +46,7 @@ namespace sd {
|
||||||
|
|
||||||
Nd4jStatus execute(Context* block) override;
|
Nd4jStatus execute(Context* block) override;
|
||||||
|
|
||||||
|
Nd4jStatus validateDataTypes(Context& block) override;
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||||
LegacyOp* clone() override;
|
LegacyOp* clone() override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -266,9 +266,6 @@ namespace sd {
|
||||||
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
|
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
|
RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
|
||||||
|
|
||||||
// FIXME: !!!
|
|
||||||
//OVERWRITE_RESULT(z);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case sd::random::AlphaDropOut: {
|
case sd::random::AlphaDropOut: {
|
||||||
|
@ -421,6 +418,32 @@ namespace sd {
|
||||||
return arrayList;
|
return arrayList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) {
|
||||||
|
if (block.isFastPath()) {
|
||||||
|
// in this case we'll roll through pre-defined outputs
|
||||||
|
auto fpo = block.fastpath_out();
|
||||||
|
for (auto v:fpo) {
|
||||||
|
if (v != nullptr) {
|
||||||
|
if (!v->isR())
|
||||||
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::pair<int,int> pair(block.nodeId(), 0);
|
||||||
|
if (block.getVariableSpace()->hasVariable(pair)) {
|
||||||
|
auto var = block.variable(pair);
|
||||||
|
if (!var->hasNDArray())
|
||||||
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||||
|
|
||||||
|
auto arr = var->getNDArray();
|
||||||
|
if (!arr->isR())
|
||||||
|
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1571,7 +1571,7 @@ namespace sd {
|
||||||
if(p.getIntParam("inplace") == 1){
|
if(p.getIntParam("inplace") == 1){
|
||||||
z.push_back(view);
|
z.push_back(view);
|
||||||
} else {
|
} else {
|
||||||
z.push_back(NDArrayFactory::create_<float>('c', {r,r}));
|
z.push_back(NDArrayFactory::create_<float>('c', {view->sizeAt(0),view->sizeAt(1)}));
|
||||||
}
|
}
|
||||||
delete arr;
|
delete arr;
|
||||||
};
|
};
|
||||||
|
|
|
@ -609,7 +609,7 @@ namespace sd {
|
||||||
|
|
||||||
nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES);
|
||||||
/*
|
|
||||||
nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -627,13 +627,13 @@ namespace sd {
|
||||||
|
|
||||||
nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES);
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", "");
|
nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", "");
|
||||||
//result += broadcast2d();
|
result += broadcast2d();
|
||||||
nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", "");
|
nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", "");
|
||||||
//result += mismatchedOrderAssign();
|
result += mismatchedOrderAssign();
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,6 +57,7 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#ifdef RELEASE_BUILD
|
#ifdef RELEASE_BUILD
|
||||||
|
|
||||||
TEST_F(PerformanceTests, test_maxpooling2d_1) {
|
TEST_F(PerformanceTests, test_maxpooling2d_1) {
|
||||||
|
|
Loading…
Reference in New Issue