benchmarks fixes (#321)

* bunch of small fixes

Signed-off-by: raver119 <raver119@gmail.com>

* validation for legacy random op

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-16 10:31:06 +03:00 committed by GitHub
parent e7a995e959
commit 4cf2afad2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 36 additions and 12 deletions

View File

@ -61,10 +61,9 @@ namespace sd {
//nd4j_printf("Total Iterations: %i\n", totalIterations);
for (int i = 0; i < totalIterations; i++) {
shape::index2coords(i, xRank, xShape, xCoords);
if (xRank > 0)
shape::index2coords(i, xRank, xShape, xCoords);
Parameters params;
for (int j = 0; j < xRank; j++) {

View File

@ -132,7 +132,7 @@ namespace sd {
// this method returns OpDescriptor, describing this Op instance
OpDescriptor *getOpDescriptor();
Nd4jStatus validateDataTypes(Context& block);
virtual Nd4jStatus validateDataTypes(Context& block);
/**
* This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s)

View File

@ -46,6 +46,7 @@ namespace sd {
Nd4jStatus execute(Context* block) override;
Nd4jStatus validateDataTypes(Context& block) override;
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
LegacyOp* clone() override;
};

View File

@ -266,9 +266,6 @@ namespace sd {
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
// FIXME: !!!
//OVERWRITE_RESULT(z);
}
break;
case sd::random::AlphaDropOut: {
@ -421,6 +418,32 @@ namespace sd {
return arrayList;
}
Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) {
if (block.isFastPath()) {
// in this case we'll roll through pre-defined outputs
auto fpo = block.fastpath_out();
for (auto v:fpo) {
if (v != nullptr) {
if (!v->isR())
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
} else {
std::pair<int,int> pair(block.nodeId(), 0);
if (block.getVariableSpace()->hasVariable(pair)) {
auto var = block.variable(pair);
if (!var->hasNDArray())
return ND4J_STATUS_BAD_ARGUMENTS;
auto arr = var->getNDArray();
if (!arr->isR())
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
return Status::OK();
}
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
}
}

View File

@ -1571,7 +1571,7 @@ namespace sd {
if(p.getIntParam("inplace") == 1){
z.push_back(view);
} else {
z.push_back(NDArrayFactory::create_<float>('c', {r,r}));
z.push_back(NDArrayFactory::create_<float>('c', {view->sizeAt(0),view->sizeAt(1)}));
}
delete arr;
};

View File

@ -609,7 +609,7 @@ namespace sd {
nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES);
/*
nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES);
@ -627,13 +627,13 @@ namespace sd {
nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES);
*/
}
nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", "");
//result += broadcast2d();
result += broadcast2d();
nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", "");
//result += mismatchedOrderAssign();
result += mismatchedOrderAssign();
return result;
}

View File

@ -57,6 +57,7 @@ public:
}
};
#ifdef RELEASE_BUILD
TEST_F(PerformanceTests, test_maxpooling2d_1) {