benchmarks fixes (#321)
* bunch of small fixes Signed-off-by: raver119 <raver119@gmail.com> * validation for legacy random op Signed-off-by: raver119 <raver119@gmail.com> * get rid of test Signed-off-by: raver119 <raver119@gmail.com>master
parent
e7a995e959
commit
4cf2afad2b
|
@ -61,9 +61,8 @@ namespace sd {
|
|||
|
||||
//nd4j_printf("Total Iterations: %i\n", totalIterations);
|
||||
|
||||
|
||||
|
||||
for (int i = 0; i < totalIterations; i++) {
|
||||
if (xRank > 0)
|
||||
shape::index2coords(i, xRank, xShape, xCoords);
|
||||
|
||||
Parameters params;
|
||||
|
|
|
@ -132,7 +132,7 @@ namespace sd {
|
|||
// this method returns OpDescriptor, describing this Op instance
|
||||
OpDescriptor *getOpDescriptor();
|
||||
|
||||
Nd4jStatus validateDataTypes(Context& block);
|
||||
virtual Nd4jStatus validateDataTypes(Context& block);
|
||||
|
||||
/**
|
||||
* This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s)
|
||||
|
|
|
@ -46,6 +46,7 @@ namespace sd {
|
|||
|
||||
Nd4jStatus execute(Context* block) override;
|
||||
|
||||
Nd4jStatus validateDataTypes(Context& block) override;
|
||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||
LegacyOp* clone() override;
|
||||
};
|
||||
|
|
|
@ -266,9 +266,6 @@ namespace sd {
|
|||
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
|
||||
|
||||
// FIXME: !!!
|
||||
//OVERWRITE_RESULT(z);
|
||||
}
|
||||
break;
|
||||
case sd::random::AlphaDropOut: {
|
||||
|
@ -421,6 +418,32 @@ namespace sd {
|
|||
return arrayList;
|
||||
}
|
||||
|
||||
Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) {
|
||||
if (block.isFastPath()) {
|
||||
// in this case we'll roll through pre-defined outputs
|
||||
auto fpo = block.fastpath_out();
|
||||
for (auto v:fpo) {
|
||||
if (v != nullptr) {
|
||||
if (!v->isR())
|
||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::pair<int,int> pair(block.nodeId(), 0);
|
||||
if (block.getVariableSpace()->hasVariable(pair)) {
|
||||
auto var = block.variable(pair);
|
||||
if (!var->hasNDArray())
|
||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||
|
||||
auto arr = var->getNDArray();
|
||||
if (!arr->isR())
|
||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1571,7 +1571,7 @@ namespace sd {
|
|||
if(p.getIntParam("inplace") == 1){
|
||||
z.push_back(view);
|
||||
} else {
|
||||
z.push_back(NDArrayFactory::create_<float>('c', {r,r}));
|
||||
z.push_back(NDArrayFactory::create_<float>('c', {view->sizeAt(0),view->sizeAt(1)}));
|
||||
}
|
||||
delete arr;
|
||||
};
|
||||
|
|
|
@ -609,7 +609,7 @@ namespace sd {
|
|||
|
||||
nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||
BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES);
|
||||
/*
|
||||
|
||||
nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||
BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES);
|
||||
|
||||
|
@ -627,13 +627,13 @@ namespace sd {
|
|||
|
||||
nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||
BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES);
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", "");
|
||||
//result += broadcast2d();
|
||||
result += broadcast2d();
|
||||
nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", "");
|
||||
//result += mismatchedOrderAssign();
|
||||
result += mismatchedOrderAssign();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
#ifdef RELEASE_BUILD
|
||||
|
||||
TEST_F(PerformanceTests, test_maxpooling2d_1) {
|
||||
|
|
Loading…
Reference in New Issue