Improve ResultSet usage in libnd4j (#281)

* libnd4j profiling DeclarableOp and Tests by replacing return ResultSet pointer by instance

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j profiling semantic change in tests cases

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections to make new ResultSet semantic works, fixed one test

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j more tests fixes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - correct copy and move assignment operators of ResultSet class

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Yurii <iuriish@yahoo.com>
Co-authored-by: raver119 <raver119@gmail.com>
master
Oleh 2020-03-10 06:42:50 +02:00 committed by GitHub
parent 57210b936c
commit c3223dbc7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 6726 additions and 7658 deletions

View File

@ -40,6 +40,8 @@ namespace sd {
Nd4jStatus _status = ND4J_STATUS_OK; Nd4jStatus _status = ND4J_STATUS_OK;
bool _removable = true; bool _removable = true;
void delContent();
public: public:
explicit ResultSet(); explicit ResultSet();

View File

@ -160,9 +160,7 @@ namespace sd {
auto result = op.evaluate(inputs); auto result = op.evaluate(inputs);
auto array = new NDArray(result->at(0)->dup()); auto array = new NDArray(result.at(0)->dup());
delete result;
return array; return array;
} }

View File

@ -83,9 +83,10 @@ namespace sd {
if (this == &other) if (this == &other)
return *this; return *this;
this->~ResultSet(); delContent();
_content = std::move(other._content); _content = std::move(other._content);
_status = other._status; _status = other._status;
_removable = other._removable; _removable = other._removable;
other._removable = false; other._removable = false;
@ -98,10 +99,10 @@ namespace sd {
if (this == &other) if (this == &other)
return *this; return *this;
this->~ResultSet(); delContent();
for (const auto v : other._content) for (const auto v : other._content)
_content.emplace_back(v); _content.push_back(v);
_status = other._status; _status = other._status;
_removable = false; _removable = false;
@ -109,13 +110,17 @@ namespace sd {
return *this; return *this;
} }
void ResultSet::delContent() {
ResultSet::~ResultSet() {
if (_removable) if (_removable)
for (auto v : _content) for (auto v : _content)
delete v; delete v;
} }
ResultSet::~ResultSet() {
delContent();
}
void ResultSet::setNonRemovable() { void ResultSet::setNonRemovable() {
_removable = false; _removable = false;
} }

View File

@ -60,7 +60,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP])); fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
// back prop pass // back prop pass
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF; ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0 NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
@ -78,18 +78,17 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
// add epsilon, feed forward // add epsilon, feed forward
inArrsFF[i]->p<double>(j, orig + EPSILON); inArrsFF[i]->p<double>(j, orig + EPSILON);
ResultSet* outArrsFF = opFF.execute(argsHolderFF); ResultSet outArrsFF = opFF.execute(argsHolderFF);
int numOutArrs = outArrsFF->size(); int numOutArrs = outArrsFF.size();
double scorePlus = 0.; double scorePlus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else else
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0); scorePlus += tmpScalar.e<double>(0);
} }
delete outArrsFF;
// subtract epsilon, feed forward // subtract epsilon, feed forward
inArrsFF[i]->p<double>(j, orig - EPSILON); inArrsFF[i]->p<double>(j, orig - EPSILON);
@ -98,12 +97,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else else
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0); scoreMinus += tmpScalar.e<double>(0);
} }
delete outArrsFF;
// restore initial element value // restore initial element value
inArrsFF[i]->p<double>(j, orig); inArrsFF[i]->p<double>(j, orig);
@ -116,7 +114,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
} }
// get analytical gradient // get analytical gradient
const double analyticGrad = outArrsBP->at(i)->e<double>(j); const double analyticGrad = outArrsBP.at(i)->e<double>(j);
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) { if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j); printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
throw std::runtime_error(""); throw std::runtime_error("");
@ -138,13 +136,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
continue; continue;
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad); printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
delete outArrsBP;
return false; return false;
} }
} }
} }
delete outArrsBP;
return true; return true;
} }

View File

@ -45,8 +45,8 @@ namespace sd {
Nd4jStatus execute(Context* block) override; Nd4jStatus execute(Context* block) override;
ResultSet* execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs); ResultSet execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
ResultSet* execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs); ResultSet execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
}; };

View File

@ -176,17 +176,17 @@ namespace sd {
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false); Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs);
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>> template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args); sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false); sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32); Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32);
sd::ResultSet* execute(const sd::OpArgsHolder& holder, bool isInplace = false); sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false);
// There methods provide various validation options // There methods provide various validation options
Nd4jStatus validateNonEmptyInput(Context& block); Nd4jStatus validateNonEmptyInput(Context& block);

View File

@ -41,8 +41,9 @@ namespace sd {
template <typename T> template <typename T>
Nd4jStatus validateAndExecute_(Context &block); Nd4jStatus validateAndExecute_(Context &block);
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false); sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false); sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
Nd4jStatus execute(Context* block) override; Nd4jStatus execute(Context* block) override;
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;

View File

@ -74,10 +74,10 @@ namespace sd {
// at first step we build fwd activation // at first step we build fwd activation
sd::ops::crelu op; sd::ops::crelu op;
auto tmpResult = op.evaluate({input}); auto tmpResult = op.evaluate({input});
if (tmpResult->status() != ND4J_STATUS_OK) if (tmpResult.status() != ND4J_STATUS_OK)
return tmpResult->status(); return tmpResult.status();
auto actv = tmpResult->at(0); auto actv = tmpResult.at(0);
// now we do RELU backward pass // now we do RELU backward pass
//actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); //actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr);
@ -85,17 +85,15 @@ namespace sd {
// now we split updated array into 2 chunks along last dimension // now we split updated array into 2 chunks along last dimension
sd::ops::concat_bp opc; sd::ops::concat_bp opc;
auto dec = opc.evaluate({input, input, actv}, {-1}); auto dec = opc.evaluate({input, input, actv}, {-1});
if (dec->status() != ND4J_STATUS_OK) if (dec.status() != ND4J_STATUS_OK)
return dec->status(); return dec.status();
// and now we subtract two parts of epsilons and pass result out // and now we subtract two parts of epsilons and pass result out
auto pos = dec->at(0); auto pos = dec.at(0);
auto neg = dec->at(1); auto neg = dec.at(1);
pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon); pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon);
delete tmpResult;
delete dec;
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }

View File

@ -102,10 +102,12 @@ namespace sd {
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width()); REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
// if (output->isEmpty()) // if (output->isEmpty())
Nd4jLong width = condition->rankOf(); Nd4jLong width = condition->rankOf();
sd::ops::Where op; sd::ops::Where op;
std::unique_ptr<ResultSet> res(op.evaluate({condition})); auto res(op.evaluate({condition}));
REQUIRE_OK(res->status()); REQUIRE_OK(res.status());
NDArray* whereTrue = res->at(0); NDArray* whereTrue = res.at(0);
if (whereTrue->isEmpty()) if (whereTrue->isEmpty())
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
for (Nd4jLong outNext = 0; outNext < width; ++outNext) { for (Nd4jLong outNext = 0; outNext < width; ++outNext) {

View File

@ -65,11 +65,12 @@ namespace sd {
auto gradX = OUTPUT_VARIABLE(0); auto gradX = OUTPUT_VARIABLE(0);
auto gradY = OUTPUT_VARIABLE(1); auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext); gradX->assign(epsNext);
sd::ops::floormod op; sd::ops::floormod op;
std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y})); auto tmpResult(op.evaluate({x, y}));
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY); epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
else // epsNext is greater than gradY else // epsNext is greater than gradY
{ {
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2); std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -77,7 +78,7 @@ namespace sd {
for (Nd4jLong d = 0; d < gap; d++) { for (Nd4jLong d = 0; d < gap; d++) {
dims[d * 2 + 1] = 1; dims[d * 2 + 1] = 1;
} }
auto tempIn((*tmpResult->at(0))(dims)); auto tempIn((*tmpResult.at(0))(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
} }
return Status::OK(); return Status::OK();

View File

@ -113,23 +113,21 @@ namespace ops {
originalIndices.linspace(0); originalIndices.linspace(0);
ops::dynamic_partition op; ops::dynamic_partition op;
auto res = op.evaluate({&originalIndices, indices}, {numPartition}); auto res = op.evaluate({&originalIndices, indices}, {numPartition});
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp; ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2); std::vector<NDArray*> partitions(numPartition * 2);
for (size_t i = 0; i < res->size(); i++) { for (size_t i = 0; i < res.size(); i++) {
partitions[i] = res->at(i); partitions[i] = res.at(i);
partitions[i + numPartition] = gradOutList[i]; partitions[i + numPartition] = gradOutList[i];
} }
auto result = stichOp.evaluate(partitions, {numPartition}); auto result = stichOp.evaluate(partitions, {numPartition});
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector()); result.at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices); outputList[1]->assign(indices);
outputList[0]->assign(result->at(0)); outputList[0]->assign(result.at(0));
// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList); // helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList);
delete res;
delete result;
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }

View File

@ -66,10 +66,10 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
sd::ops::gather op; sd::ops::gather op;
std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0})); auto result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result->at(0)); output->assign(result.at(0));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -95,8 +95,8 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// forward steps // forward steps
sd::ops::dynamic_rnn dynamicRnn; sd::ops::dynamic_rnn dynamicRnn;
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW->at(1)); hFWFinal->assign(resultsFW.at(1));
auto seqLen = maxTimeStep; auto seqLen = maxTimeStep;
if(seqLen == nullptr) { if(seqLen == nullptr) {
@ -108,22 +108,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// reverse x // reverse x
sd::ops::reverse_sequence reverse; sd::ops::reverse_sequence reverse;
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0}); auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence."); REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0); auto revInput = resultsIn.at(0);
// backward steps // backward steps
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
hBWFinal->assign(resultsBW->at(1)); hBWFinal->assign(resultsBW.at(1));
// reverse hBWtemp // reverse hBWtemp
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0}); auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0)); hBW->assign(resultsOut.at(0));
delete resultsOut;
delete resultsBW;
delete resultsIn;
delete resultsFW;
if(seqLen != maxTimeStep) if(seqLen != maxTimeStep)
delete seqLen; delete seqLen;
@ -228,12 +223,6 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
} }
} }
} }

View File

@ -52,14 +52,13 @@ namespace helpers {
sd::ops::unique opUnique; sd::ops::unique opUnique;
auto uResult = opUnique.evaluate({&arrayFull}); auto uResult = opUnique.evaluate({&arrayFull});
if (Status::OK() != uResult->status()) if (Status::OK() != uResult.status())
throw std::runtime_error("multiUnique: cannot execute unique op properly."); throw std::runtime_error("multiUnique: cannot execute unique op properly.");
auto uniqueVals = uResult->at(0); auto uniqueVals = uResult.at(0);
bool res = uniqueVals->lengthOf() == arrayFull.lengthOf(); bool res = uniqueVals->lengthOf() == arrayFull.lengthOf();
delete uResult;
return res; return res;
} }

View File

@ -64,7 +64,7 @@ namespace sd {
block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList);
} }
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) { ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
std::vector<NDArray*> ins(inputs); std::vector<NDArray*> ins(inputs);
std::vector<double> tas(tArgs); std::vector<double> tas(tArgs);
std::vector<int> ias(iArgs); std::vector<int> ias(iArgs);
@ -94,7 +94,7 @@ namespace sd {
return status; return status;
} }
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) { ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
VariableSpace varSpace; VariableSpace varSpace;
int nodeId = 119; int nodeId = 119;
@ -132,8 +132,8 @@ namespace sd {
Nd4jStatus result = this->validateAndExecute(block); Nd4jStatus result = this->validateAndExecute(block);
auto res = new ResultSet(); ResultSet res;
res->setStatus(result); res.setStatus(result);
for (int e = 0; e < DataTypeUtils::max<int>(); e++) { for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e); std::pair<int,int> pair(1, e);
@ -143,10 +143,10 @@ namespace sd {
auto arr = var->getNDArray(); auto arr = var->getNDArray();
if (arr->isAttached()) { if (arr->isAttached()) {
auto d = arr->detach(); auto d = arr->detach();
res->push_back(d); res.push_back(d);
} else { } else {
var->markRemovable(false); var->markRemovable(false);
res->push_back(arr); res.push_back(arr);
} }
} }
} else } else

View File

@ -962,12 +962,12 @@ namespace sd {
return execute(&ctx); return execute(&ctx);
} }
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>()); return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
std::vector<Nd4jLong> realArgs; std::vector<Nd4jLong> realArgs;
for (auto v:iArgs) for (auto v:iArgs)
realArgs.emplace_back(v); realArgs.emplace_back(v);
@ -976,12 +976,12 @@ namespace sd {
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>()); return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>());
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
std::vector<double> realArgs; std::vector<double> realArgs;
for (auto v:tArgs) for (auto v:tArgs)
realArgs.emplace_back(v); realArgs.emplace_back(v);
@ -990,21 +990,21 @@ namespace sd {
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>()); return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>()); return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>());
} }
template <> template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs); return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
} }
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) { sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
VariableSpace variableSpace; VariableSpace variableSpace;
//ResultSet arrayList; //ResultSet arrayList;
FlowPath fp; FlowPath fp;
@ -1041,11 +1041,11 @@ namespace sd {
block.getDArguments()->push_back(dArgs.at(e)); block.getDArguments()->push_back(dArgs.at(e));
Nd4jStatus status = this->execute(&block); Nd4jStatus status = this->execute(&block);
auto arrayList = new ResultSet(); ResultSet arrayList;
if (isInplace) if (isInplace)
arrayList->setNonRemovable(); arrayList.setNonRemovable();
arrayList->setStatus(status); arrayList.setStatus(status);
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return arrayList; return arrayList;
@ -1058,23 +1058,23 @@ namespace sd {
if (!arr->isAttached()) { if (!arr->isAttached()) {
var->markRemovable(false); var->markRemovable(false);
arr->setContext(sd::LaunchContext::defaultContext()); arr->setContext(sd::LaunchContext::defaultContext());
arrayList->push_back(arr); arrayList.push_back(arr);
} else { } else {
arrayList->push_back(arr->detach()); arrayList.push_back(arr->detach());
} }
} else } else
break; break;
} }
} else { } else {
for (auto v:inputs) { for (auto v:inputs) {
arrayList->push_back(v); arrayList.push_back(v);
} }
} }
return arrayList; return arrayList;
} }
sd::ResultSet* sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) { sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
// FIXME: add DArgs to OpArgsHolder // FIXME: add DArgs to OpArgsHolder
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace); return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace);
} }

View File

@ -357,20 +357,20 @@ namespace sd {
return DeclarableOp::execute(block); return DeclarableOp::execute(block);
} }
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) { sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
std::vector<NDArray*> ins(inputs); std::vector<NDArray*> ins(inputs);
std::vector<double> tas(tArgs); std::vector<double> tas(tArgs);
std::vector<int> ias(iArgs); std::vector<int> ias(iArgs);
return this->execute(rng, ins, tas, ias, isInplace); return this->execute(rng, ins, tas, ias, isInplace);
} }
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) { sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
VariableSpace variableSpace; VariableSpace variableSpace;
auto arrayList = new ResultSet(); ResultSet arrayList;
//ResultSet arrayList; //ResultSet arrayList;
if (isInplace) if (isInplace)
arrayList->setNonRemovable(); arrayList.setNonRemovable();
int cnt = -1; int cnt = -1;
std::vector<int> in; std::vector<int> in;
@ -398,7 +398,7 @@ namespace sd {
block.getIArguments()->emplace_back(iArgs.at(e)); block.getIArguments()->emplace_back(iArgs.at(e));
Nd4jStatus status = this->execute(&block); Nd4jStatus status = this->execute(&block);
arrayList->setStatus(status); arrayList.setStatus(status);
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return arrayList; return arrayList;
@ -410,9 +410,9 @@ namespace sd {
auto arr = var->getNDArray(); auto arr = var->getNDArray();
if (!arr->isAttached()) { if (!arr->isAttached()) {
var->markRemovable(false); var->markRemovable(false);
arrayList->push_back(arr); arrayList.push_back(arr);
} else { } else {
arrayList->push_back(arr->detach()); arrayList.push_back(arr->detach());
} }
} else } else
break; break;

View File

@ -44,9 +44,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
sd::ops::dot_product_attention op; sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
/* /*
@ -72,9 +70,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
sd::ops::dot_product_attention op; sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) { TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
@ -86,9 +82,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
sd::ops::dot_product_attention op; sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
/* /*
@ -118,9 +112,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
sd::ops::dot_product_attention op; sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
/* /*
@ -154,9 +146,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
sd::ops::multi_head_dot_product_attention op; sd::ops::multi_head_dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
/* /*
@ -198,9 +188,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
sd::ops::multi_head_dot_product_attention op; sd::ops::multi_head_dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
/* /*

View File

@ -39,13 +39,11 @@ TEST_F(BackpropTests, Test_Add_1) {
sd::ops::add_bp op; sd::ops::add_bp op;
auto result = op.evaluate({&x, &y, &e}); auto result = op.evaluate({&x, &y, &e});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto eps = result->at(0); auto eps = result.at(0);
auto grad = result->at(1); auto grad = result.at(1);
ASSERT_TRUE(x.isSameShape(eps)); ASSERT_TRUE(x.isSameShape(eps));
ASSERT_TRUE(y.isSameShape(grad)); ASSERT_TRUE(y.isSameShape(grad));
delete result;
} }

View File

@ -137,13 +137,12 @@ TEST_F(BooleanOpsTests, test_where_1) {
sd::ops::choose op; sd::ops::choose op;
auto result = op.evaluate({&x, &y}, {3}); auto result = op.evaluate({&x, &y}, {3});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
//z->printIndexedBuffer("z"); //z->printIndexedBuffer("z");
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }

View File

@ -48,9 +48,9 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
sd::ops::add op; sd::ops::add op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
//exp.printIndexedBuffer("E A"); //exp.printIndexedBuffer("E A");
//z->printIndexedBuffer("Z"); //z->printIndexedBuffer("Z");
@ -58,7 +58,6 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -75,14 +74,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
sd::ops::multiply op; sd::ops::multiply op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -100,14 +97,13 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
sd::ops::squaredsubtract op; sd::ops::squaredsubtract op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -119,14 +115,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
sd::ops::subtract op; sd::ops::subtract op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -138,14 +132,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
sd::ops::add op; sd::ops::add op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -156,14 +148,12 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
sd::ops::maximum op; sd::ops::maximum op;
auto result = op.evaluate({&x, &row}); auto result = op.evaluate({&x, &row});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -174,15 +164,14 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
sd::ops::minimum op; sd::ops::minimum op;
auto result = op.evaluate({&x, &col}); auto result = op.evaluate({&x, &col});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -283,14 +272,13 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
sd::ops::add op; sd::ops::add op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -333,11 +321,9 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
sd::ops::subtract op; sd::ops::subtract op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
delete result;
} }
TEST_F(BroadcastableOpsTests, Test_Subtract_3) { TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
@ -511,13 +497,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
sd::ops::multiply op; sd::ops::multiply op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
delete result;
} }
TEST_F(BroadcastableOpsTests, Test_Multiply_8) { TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
@ -527,13 +512,11 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
sd::ops::multiply op; sd::ops::multiply op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z)); ASSERT_TRUE(e.equalsTo(z));
delete result;
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -606,14 +589,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
sd::ops::maximum op; sd::ops::maximum op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
TEST_F(BroadcastableOpsTests, broadcast_empty_4) { TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
@ -625,14 +606,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
sd::ops::maximum op; sd::ops::maximum op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
TEST_F(BroadcastableOpsTests, broadcast_empty_5) { TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
@ -644,14 +624,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
sd::ops::realdiv op; sd::ops::realdiv op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
TEST_F(BroadcastableOpsTests, broadcast_empty_6) { TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
@ -663,14 +642,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
sd::ops::realdiv op; sd::ops::realdiv op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
TEST_F(BroadcastableOpsTests, broadcast_empty_7) { TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
@ -682,14 +660,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
sd::ops::realdiv op; sd::ops::realdiv op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
@ -718,15 +694,13 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
sd::ops::greater op; sd::ops::greater op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z"); // z->printShapeInfo("z");
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z)); ASSERT_TRUE(e.equalsTo(*z));
delete result;
} }
TEST_F(BroadcastableOpsTests, broadcast_bool_1) { TEST_F(BroadcastableOpsTests, broadcast_bool_1) {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -48,9 +48,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
sd::ops::conv2d op; sd::ops::conv2d op;
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status()); ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status());
delete result;
} }
TEST_F(DataTypesValidationTests, Basic_Test_2) { TEST_F(DataTypesValidationTests, Basic_Test_2) {
@ -63,13 +61,12 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
sd::ops::conv2d op; sd::ops::conv2d op;
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -47,13 +47,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
sd::ops::scatter_upd op; sd::ops::scatter_upd op;
auto result = op.evaluate({ &x, &y, &w }); auto result = op.evaluate({ &x, &y, &w });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
TEST_F(DeclarableOpsTests16, scatter_upd_2) { TEST_F(DeclarableOpsTests16, scatter_upd_2) {
@ -67,13 +65,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
sd::ops::scatter_upd op; sd::ops::scatter_upd op;
auto result = op.evaluate({ &x, &indices, &updates }); auto result = op.evaluate({ &x, &indices, &updates });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
TEST_F(DeclarableOpsTests16, scatter_upd_3) { TEST_F(DeclarableOpsTests16, scatter_upd_3) {
@ -136,13 +132,11 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
sd::ops::bits_hamming_distance op; sd::ops::bits_hamming_distance op;
auto result = op.evaluate({ &x, &y }); auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
@ -167,10 +161,8 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
sd::ops::cast op; sd::ops::cast op;
auto result = op.evaluate({&x}, {10}); auto result = op.evaluate({&x}, {10});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result.at(0));
delete result;
} }
TEST_F(DeclarableOpsTests16, test_range_1) { TEST_F(DeclarableOpsTests16, test_range_1) {
@ -717,8 +709,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 }, auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
@ -767,7 +757,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 }, auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
@ -798,7 +787,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 }, auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
@ -826,7 +814,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }
@ -850,7 +837,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
@ -891,8 +877,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, { auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
@ -937,8 +921,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, { auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
@ -983,7 +965,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, { auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
@ -1010,7 +991,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, { auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
@ -1037,8 +1017,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, { auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
@ -1061,7 +1039,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
#endif #endif
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
@ -1096,5 +1073,4 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }

View File

@ -49,9 +49,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
sd::ops::compat_sparse_to_dense op; sd::ops::compat_sparse_to_dense op;
auto result = op.evaluate({&ranges, &shape, &values, &def}); auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
@ -64,9 +62,8 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
sd::ops::compat_sparse_to_dense op; sd::ops::compat_sparse_to_dense op;
auto result = op.evaluate({&ranges, &shape, &values, &def}); auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
@ -78,11 +75,11 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
sd::ops::compat_string_split op; sd::ops::compat_string_split op;
auto result = op.evaluate({&x, &delimiter}); auto result = op.evaluate({&x, &delimiter});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(2, result->size()); ASSERT_EQ(2, result.size());
auto z0 = result->at(0); auto z0 = result.at(0);
auto z1 = result->at(1); auto z1 = result.at(1);
ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.isSameShape(z1));
@ -90,5 +87,4 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
ASSERT_EQ(exp0, *z0); ASSERT_EQ(exp0, *z0);
ASSERT_EQ(exp1, *z1); ASSERT_EQ(exp1, *z1);
delete result;
} }

View File

@ -62,10 +62,8 @@ TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
sd::ops::conv1d_bp op; sd::ops::conv1d_bp op;
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0}); auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(DeclarableOpsTests19, test_squeeze_1) { TEST_F(DeclarableOpsTests19, test_squeeze_1) {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -51,14 +51,12 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
sd::ops::choose op; sd::ops::choose op;
//greater than test //greater than test
auto result = op.evaluate({&x}, {0.0},{3}); auto result = op.evaluate({&x}, {0.0},{3});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(1); auto z = result.at(1);
ASSERT_EQ(148,z->e<double>(0)); ASSERT_EQ(148,z->e<double>(0));
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result;
} }
/* /*

View File

@ -67,9 +67,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({empty, vector}, {}, {0}); auto result = op.evaluate({empty, vector}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z shape"); // z->printShapeInfo("z shape");
// z->printIndexedBuffer("z buffr"); // z->printIndexedBuffer("z buffr");
@ -78,7 +78,6 @@ TEST_F(EmptyTests, Test_Concat_1) {
delete empty; delete empty;
delete vector; delete vector;
delete result;
} }
@ -92,9 +91,9 @@ TEST_F(EmptyTests, Test_Concat_2) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0}); auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z shape"); // z->printShapeInfo("z shape");
// z->printIndexedBuffer("z buffr"); // z->printIndexedBuffer("z buffr");
@ -104,7 +103,6 @@ TEST_F(EmptyTests, Test_Concat_2) {
delete empty; delete empty;
delete scalar1; delete scalar1;
delete scalar2; delete scalar2;
delete result;
} }
TEST_F(EmptyTests, Test_Concat_3) { TEST_F(EmptyTests, Test_Concat_3) {
@ -117,13 +115,12 @@ TEST_F(EmptyTests, Test_Concat_3) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
delete result;
} }
TEST_F(EmptyTests, Test_Concat_4) { TEST_F(EmptyTests, Test_Concat_4) {
@ -136,13 +133,11 @@ TEST_F(EmptyTests, Test_Concat_4) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(exp, *z); ASSERT_EQ(exp, *z);
delete result;
} }
TEST_F(EmptyTests, Test_Reshape_1) { TEST_F(EmptyTests, Test_Reshape_1) {
@ -153,12 +148,11 @@ TEST_F(EmptyTests, Test_Reshape_1) {
sd::ops::reshape op; sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {}); auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result->at(0)); ASSERT_EQ(exp, *result.at(0));
delete empty; delete empty;
delete result;
} }
TEST_F(EmptyTests, Test_Reshape_3) { TEST_F(EmptyTests, Test_Reshape_3) {
@ -168,14 +162,13 @@ TEST_F(EmptyTests, Test_Reshape_3) {
sd::ops::reshape op; sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
TEST_F(EmptyTests, Test_dup_1) { TEST_F(EmptyTests, Test_dup_1) {
@ -197,12 +190,11 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
sd::ops::scatter_upd op; sd::ops::scatter_upd op;
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(x, *z); ASSERT_EQ(x, *z);
delete result;
} }
TEST_F(EmptyTests, test_empty_scatter_2) { TEST_F(EmptyTests, test_empty_scatter_2) {
@ -288,17 +280,15 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
sd::ops::reshape op; sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {}); auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0->status()); ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0->at(0); auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0); ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {}); auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1->status()); ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1->at(0); auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1); ASSERT_EQ(e1, *z1);
delete result0;
delete result1;
} }
@ -309,12 +299,11 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
sd::ops::matmul op; sd::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
TEST_F(EmptyTests, test_empty_matmul_2) { TEST_F(EmptyTests, test_empty_matmul_2) {
@ -324,10 +313,8 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
sd::ops::matmul op; sd::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }

View File

@ -1889,20 +1889,19 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) {
OpArgsHolder holderFF({&input}, {}, {2, 3}); OpArgsHolder holderFF({&input}, {}, {2, 3});
sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly
auto results = opFF.execute(holderFF); auto results = opFF.execute(holderFF);
auto tiled = results->at(0); auto tiled = results.at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(tiled)); ASSERT_TRUE(exp.isSameShape(tiled));
ASSERT_TRUE(exp.equalsTo(tiled)); ASSERT_TRUE(exp.equalsTo(tiled));
delete results;
OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true);
sd::ops::tile_bp opBP; sd::ops::tile_bp opBP;
results = opBP.execute(holderBP); results = opBP.execute(holderBP);
auto gradI = results->at(0); auto gradI = results.at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
} }

View File

@ -47,13 +47,13 @@ TEST_F(IndexingTests, StridedSlice_1) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -66,14 +66,14 @@ TEST_F(IndexingTests, StridedSlice_2) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1}); auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -86,14 +86,14 @@ TEST_F(IndexingTests, StridedSlice_3) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2}); auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -109,15 +109,15 @@ TEST_F(IndexingTests, SimpleSlice_1) {
sd::ops::slice op; sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -135,15 +135,15 @@ TEST_F(IndexingTests, SimpleSlice_2) {
sd::ops::slice op; sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, SimpleSlice_3) { TEST_F(IndexingTests, SimpleSlice_3) {
@ -160,15 +160,15 @@ TEST_F(IndexingTests, SimpleSlice_3) {
sd::ops::slice op; sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3}); auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, SimpleSlice_4) { TEST_F(IndexingTests, SimpleSlice_4) {
@ -180,14 +180,14 @@ TEST_F(IndexingTests, SimpleSlice_4) {
sd::ops::slice op; sd::ops::slice op;
auto result = op.evaluate({&input, &start, &stop}); auto result = op.evaluate({&input, &start, &stop});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -204,16 +204,16 @@ TEST_F(IndexingTests, MaskedSlice_0) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z"); // z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -230,14 +230,14 @@ TEST_F(IndexingTests, MaskedSlice_00) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -254,16 +254,16 @@ TEST_F(IndexingTests, MaskedSlice_1) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z"); // z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, MaskedSlice_2) { TEST_F(IndexingTests, MaskedSlice_2) {
@ -275,14 +275,14 @@ TEST_F(IndexingTests, MaskedSlice_2) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -295,14 +295,14 @@ TEST_F(IndexingTests, MaskedSlice_3) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -315,15 +315,15 @@ TEST_F(IndexingTests, MaskedSlice_4) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, Live_Slice_1) { TEST_F(IndexingTests, Live_Slice_1) {
@ -338,16 +338,16 @@ TEST_F(IndexingTests, Live_Slice_1) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3}); auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo("z shape"); // z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -361,14 +361,14 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, Test_StridedSlice_2) { TEST_F(IndexingTests, Test_StridedSlice_2) {
@ -381,16 +381,16 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Z"); // z->printIndexedBuffer("Z");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -404,14 +404,14 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
sd::ops::strided_slice op; sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -426,16 +426,16 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1}); // auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
//z->printIndexedBuffer("Z"); //z->printIndexedBuffer("Z");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(IndexingTests, Test_Subarray_Strided_1) { TEST_F(IndexingTests, Test_Subarray_Strided_1) {
@ -458,13 +458,13 @@ TEST_F(IndexingTests, MaskedSlice_5) {
sd::ops::strided_slice<float> op; sd::ops::strided_slice<float> op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3}); auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
*/ */

View File

@ -64,13 +64,13 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(LegacyOpsTests, Reciprocal_1) { TEST_F(LegacyOpsTests, Reciprocal_1) {
@ -121,12 +121,12 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
auto result = op.evaluate({&x, &y}, {}, {}); auto result = op.evaluate({&x, &y}, {}, {});
auto z = result->at(0); auto z = result.at(0);
//z->printBuffer("Z"); //z->printBuffer("Z");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(LegacyOpsTests, Scalar_Test_1) { TEST_F(LegacyOpsTests, Scalar_Test_1) {
@ -154,10 +154,10 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
sd::ops::LegacyScalarOp op(scalar::Add, y); sd::ops::LegacyScalarOp op(scalar::Add, y);
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -169,14 +169,14 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
// z->printBuffer("ReduceTest1"); // z->printBuffer("ReduceTest1");
ASSERT_TRUE(z->isScalar()); ASSERT_TRUE(z->isScalar());
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f); ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
delete result;
} }
@ -188,16 +188,16 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1}); auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
auto result = op.evaluate({&x, &axis}, {}, {}); auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum, {1}); auto exp = x.reduceAlongDimension(reduce::Sum, {1});
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -209,15 +209,15 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
sd::ops::LegacyReduceSameOp op(reduce::Sum); sd::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.evaluate({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum,{1}); auto exp = x.reduceAlongDimension(reduce::Sum,{1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -229,16 +229,16 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
sd::ops::LegacyReduceSameOp op(reduce::Sum); sd::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.evaluate({&x, &indices}, {}, {}, {true}); auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
// indices.printShapeInfo("Indices shape"); // indices.printShapeInfo("Indices shape");
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
// z->printIndexedBuffer("Output reduce 4"); // z->printIndexedBuffer("Output reduce 4");
// exp.printIndexedBuffer("Expected reduce 4"); // exp.printIndexedBuffer("Expected reduce 4");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(LegacyOpsTests, ReduceTests_5) { TEST_F(LegacyOpsTests, ReduceTests_5) {
@ -249,14 +249,14 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
auto result = op.evaluate({&x}); auto result = op.evaluate({&x});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
// z->printBuffer("ReduceTest1"); // z->printBuffer("ReduceTest1");
ASSERT_TRUE(z->isScalar()); ASSERT_TRUE(z->isScalar());
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f); ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
delete result;
} }
@ -268,16 +268,16 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
auto result = op.evaluate({&x, &axis}, {}, {}); auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean, {1}); auto exp = x.reduceAlongDimension(reduce::Mean, {1});
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -289,15 +289,15 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
sd::ops::LegacyReduceFloatOp op(reduce::Mean); sd::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.evaluate({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean,{1}); auto exp = x.reduceAlongDimension(reduce::Mean,{1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -309,17 +309,17 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
sd::ops::LegacyReduceFloatOp op(reduce::Mean); sd::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.evaluate({&x, &indices}, {}, {}, {true}); auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0); auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
// z->printIndexedBuffer("Reduce8 output"); // z->printIndexedBuffer("Reduce8 output");
// z->printShapeInfo("Reduce8 shape"); // z->printShapeInfo("Reduce8 shape");
// exp.printShapeInfo("Reduce8 expected shape"); // exp.printShapeInfo("Reduce8 expected shape");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -331,14 +331,14 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(z->isScalar()); ASSERT_TRUE(z->isScalar());
ASSERT_EQ(24, z->e<int>(0)); ASSERT_EQ(24, z->e<int>(0));
delete result;
} }
@ -351,9 +351,9 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
auto result = op.evaluate({&x, &indices}, {}, {}); auto result = op.evaluate({&x, &indices}, {}, {});
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Hello indexreduce2"); // z->printIndexedBuffer("Hello indexreduce2");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
//ASSERT_EQ(4, z->e<int>(0)); //ASSERT_EQ(4, z->e<int>(0));
@ -362,7 +362,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
//ASSERT_EQ(4, z->e<int>(3)); //ASSERT_EQ(4, z->e<int>(3));
//ASSERT_EQ(4, z->e<int>(4)); //ASSERT_EQ(4, z->e<int>(4));
delete result;
} }
TEST_F(LegacyOpsTests, BroadcastingTests_1) { TEST_F(LegacyOpsTests, BroadcastingTests_1) {

View File

@ -39,7 +39,7 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
auto result = op.execute(&list, {&x}, {}, {1}); auto result = op.execute(&list, {&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(1, list.elements()); ASSERT_EQ(1, list.elements());
@ -47,8 +47,8 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
ASSERT_EQ(2, list.elements()); ASSERT_EQ(2, list.elements());
delete result;
delete result2;
} }
TEST_F(ListOperationsTests, BasicTest_Stack_1) { TEST_F(ListOperationsTests, BasicTest_Stack_1) {
@ -66,15 +66,15 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
auto result = op.execute(&list, {}, {}, {1}); auto result = op.execute(&list, {}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printShapeInfo(); // z->printShapeInfo();
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
@ -93,10 +93,10 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
auto result = op.execute(&list, {&x}, {}, {0}); auto result = op.execute(&list, {&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(list.elements(), 10); ASSERT_EQ(list.elements(), 10);
// auto z = result->at(0); // auto z = result.at(0);
// z->printShapeInfo("The first of"); // z->printShapeInfo("The first of");
// ASSERT_TRUE(exp.isSameShape(z)); // ASSERT_TRUE(exp.isSameShape(z));
// ASSERT_TRUE(exp.equalsTo(z)); // ASSERT_TRUE(exp.equalsTo(z));
@ -107,7 +107,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
delete row; delete row;
} }
delete result;
} }
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { //TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
@ -126,20 +126,20 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
// //
// auto result = op.execute(nullptr, {&x}, {}, {0}); // auto result = op.execute(nullptr, {&x}, {}, {0});
// //
// ASSERT_EQ(ND4J_STATUS_OK, result->status()); // ASSERT_EQ(ND4J_STATUS_OK, result.status());
// ASSERT_EQ(result->size(), 10); // ASSERT_EQ(result->size(), 10);
// //
// // auto z = result->at(0); // // auto z = result.at(0);
//// z->printShapeInfo("The first of"); //// z->printShapeInfo("The first of");
//// ASSERT_TRUE(exp.isSameShape(z)); //// ASSERT_TRUE(exp.isSameShape(z));
//// ASSERT_TRUE(exp.equalsTo(z)); //// ASSERT_TRUE(exp.equalsTo(z));
// for (int e = 0; e < 10; e++) { // for (int e = 0; e < 10; e++) {
// auto row = result->at(e); // auto row = result.at(e);
// ASSERT_TRUE(row->equalsTo(tads->at(e))); // ASSERT_TRUE(row->equalsTo(tads->at(e)));
// //list.write(e, row); // //list.write(e, row);
// } // }
// //
// delete result; //
// delete tads; // delete tads;
//} //}
@ -160,14 +160,14 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) {
auto result = op.execute(&list, {}, {}, {4}); auto result = op.execute(&list, {}, {}, {4});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Pick_1) { TEST_F(ListOperationsTests, BasicTest_Pick_1) {
@ -192,14 +192,14 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) {
sd::ops::pick_list op; sd::ops::pick_list op;
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3}); auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Size_1) { TEST_F(ListOperationsTests, BasicTest_Size_1) {
@ -217,14 +217,14 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) {
auto result = op.execute(&list, {}, {}, {1}); auto result = op.execute(&list, {}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Create_1) { TEST_F(ListOperationsTests, BasicTest_Create_1) {
@ -235,12 +235,12 @@ TEST_F(ListOperationsTests, BasicTest_Create_1) {
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1}); auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
// we return flow as well // we return flow as well
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Split_1) { TEST_F(ListOperationsTests, BasicTest_Split_1) {
@ -283,7 +283,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
sd::ops::split_list op; sd::ops::split_list op;
auto result = op.execute(&list, {&matrix, &lengths}, {}, {}); auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(3, list.height()); ASSERT_EQ(3, list.height());
@ -296,7 +296,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Scatter_1) { TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
@ -319,7 +319,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
sd::ops::scatter_list op; sd::ops::scatter_list op;
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {}); auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
for (int e = 0; e < 10; e++) { for (int e = 0; e < 10; e++) {
auto row = tads.at(9 - e); auto row = tads.at(9 - e);
@ -329,7 +329,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
ASSERT_TRUE(chunk->equalsTo(row)); ASSERT_TRUE(chunk->equalsTo(row));
} }
delete result;
} }
TEST_F(ListOperationsTests, BasicTest_Clone_1) { TEST_F(ListOperationsTests, BasicTest_Clone_1) {
@ -385,10 +385,10 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
sd::ops::gather_list op; sd::ops::gather_list op;
auto result = op.execute(&list, {&indices}, {}, {}); auto result = op.execute(&list, {&indices}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(1, result->size()); ASSERT_EQ(1, result.size());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
@ -397,7 +397,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ListOperationsTests, GraphTests_Sequential_1) { TEST_F(ListOperationsTests, GraphTests_Sequential_1) {

View File

@ -134,13 +134,11 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
sd::ops::add op; sd::ops::add op;
auto result = op.evaluate({&x, &y}); auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -66,7 +66,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true); auto row1 = syn1({1,2, 0,0}, true);
@ -74,7 +74,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
ASSERT_EQ(exp0, row0); ASSERT_EQ(exp0, row0);
ASSERT_EQ(exp1, row1); ASSERT_EQ(exp1, row1);
delete result;
} }
TEST_F(NlpTests, basic_sg_hs_test_2) { TEST_F(NlpTests, basic_sg_hs_test_2) {
@ -107,7 +107,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true); auto row1 = syn1({1,2, 0,0}, true);
@ -117,7 +117,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
ASSERT_EQ(exp1, row1); ASSERT_EQ(exp1, row1);
ASSERT_EQ(exp2, row2); ASSERT_EQ(exp2, row2);
delete result;
} }
TEST_F(NlpTests, basic_sg_hs_test_3) { TEST_F(NlpTests, basic_sg_hs_test_3) {
@ -159,7 +159,7 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result0->status()); ASSERT_EQ(Status::OK(), result0.status());
auto row00 = syn00({0,1, 0,0}, true); auto row00 = syn00({0,1, 0,0}, true);
auto row01 = syn01({0,1, 0,0}, true); auto row01 = syn01({0,1, 0,0}, true);
@ -168,9 +168,6 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
ASSERT_EQ(row2, row1); ASSERT_EQ(row2, row1);
ASSERT_EQ(row00, row01); ASSERT_EQ(row00, row01);
delete result0;
delete result1;
} }
TEST_F(NlpTests, basic_sg_hs_ns_test_1) { TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
@ -192,9 +189,9 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(NlpTests, basic_sg_ns_test_1) { TEST_F(NlpTests, basic_sg_ns_test_1) {
@ -227,14 +224,14 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({1,2, 0,0}, true); auto row0 = syn0({1,2, 0,0}, true);
ASSERT_EQ(exp0, row0); ASSERT_EQ(exp0, row0);
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
delete result;
} }
TEST_F(NlpTests, basic_cb_hs_test_1) { TEST_F(NlpTests, basic_cb_hs_test_1) {
@ -269,7 +266,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
sd::ops::cbow op; sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row_s0_0 = syn0({0,1, 0,0}, true); auto row_s0_0 = syn0({0,1, 0,0}, true);
auto row_s0_1 = syn0({1,2, 0,0}, true); auto row_s0_1 = syn0({1,2, 0,0}, true);
@ -287,7 +284,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
ASSERT_EQ(exp1, row_s1_5); ASSERT_EQ(exp1, row_s1_5);
ASSERT_EQ(exp2, row_s1_6); ASSERT_EQ(exp2, row_s1_6);
delete result;
} }
TEST_F(NlpTests, basic_cb_ns_test_1) { TEST_F(NlpTests, basic_cb_ns_test_1) {
@ -323,7 +320,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
sd::ops::cbow op; sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row_s0_0 = syn0({0,1, 0,0}, true); auto row_s0_0 = syn0({0,1, 0,0}, true);
auto row_s0_1 = syn0({1,2, 0,0}, true); auto row_s0_1 = syn0({1,2, 0,0}, true);
@ -339,7 +336,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
ASSERT_EQ(exp0, row_s0_2); ASSERT_EQ(exp0, row_s0_2);
ASSERT_EQ(exp2, row_s1_6); ASSERT_EQ(exp2, row_s1_6);
delete result;
} }
TEST_F(NlpTests, test_sg_hs_batch_1) { TEST_F(NlpTests, test_sg_hs_batch_1) {
@ -372,7 +369,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true); auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true); auto row1 = syn1({1,2, 0,0}, true);
@ -382,7 +379,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
delete result;
} }
TEST_F(NlpTests, test_sg_ns_batch_1) { TEST_F(NlpTests, test_sg_ns_batch_1) {
@ -416,9 +413,9 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
sd::ops::skipgram op; sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true); auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
delete result;
} }
TEST_F(NlpTests, test_cbow_hs_batch_1) { TEST_F(NlpTests, test_cbow_hs_batch_1) {
@ -449,7 +446,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
sd::ops::cbow op; sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto exp0 = NDArrayFactory::create<float>('c', {1, 10}); auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
auto exp1 = NDArrayFactory::create<float>('c', {1, 10}); auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
@ -474,5 +471,4 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
ASSERT_EQ(exp1, row_s1_5); ASSERT_EQ(exp1, row_s1_5);
ASSERT_EQ(exp2, row_s1_6); ASSERT_EQ(exp2, row_s1_6);
delete result;
} }

File diff suppressed because it is too large Load Diff

View File

@ -260,9 +260,9 @@ TEST_F(RNGTests, Test_Gaussian_21) {
auto result = op.evaluate({&x0}, {}, {}); auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal"); //x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal"); //x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK()); ASSERT_TRUE(result.status() == Status::OK());
auto mean = result->at(0); auto mean = result.at(0);
auto variance = result->at(1); auto variance = result.at(1);
// mean->printIndexedBuffer("Mean"); // mean->printIndexedBuffer("Mean");
// variance->printIndexedBuffer("Variance"); // variance->printIndexedBuffer("Variance");
@ -270,7 +270,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f); ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f);
ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f); ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f);
delete result;
} }
#ifdef DEBUG_BUILD #ifdef DEBUG_BUILD
@ -292,15 +292,15 @@ TEST_F(RNGTests, Test_Gaussian_22) {
auto result = op.evaluate({&x0}, {}, {}); auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal"); //x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal"); //x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK()); ASSERT_TRUE(result.status() == Status::OK());
auto mean0 = result->at(0); auto mean0 = result.at(0);
auto variance0 = result->at(1); auto variance0 = result.at(1);
//mean0->printIndexedBuffer("Mean"); //mean0->printIndexedBuffer("Mean");
//variance0->printIndexedBuffer("Variance"); //variance0->printIndexedBuffer("Variance");
ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f); ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f);
ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f); ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f);
delete result;
} }
TEST_F(RNGTests, Test_Gaussian_3) { TEST_F(RNGTests, Test_Gaussian_3) {
@ -413,18 +413,17 @@ TEST_F(RNGTests, Test_Truncated_21) {
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5); ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
sd::ops::moments op; sd::ops::moments op;
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE"); // result.at(0)->printBuffer("MEAN");
delete result; // result.at(1)->printBuffer("VARIANCE");
sd::ops::reduce_min minOp; sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp; sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated"); // minRes->at(0)->printBuffer("MIN for Truncated");
// maxRes->at(0)->printBuffer("MAX for Truncated"); // maxRes->at(0)->printBuffer("MAX for Truncated");
delete minRes;
delete maxRes;
} }
TEST_F(RNGTests, Test_Truncated_22) { TEST_F(RNGTests, Test_Truncated_22) {
@ -460,18 +459,15 @@ TEST_F(RNGTests, Test_Truncated_22) {
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52); ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
sd::ops::moments op; sd::ops::moments op;
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE");
delete result;
sd::ops::reduce_min minOp; sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp; sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated2"); // minRes->at(0)->printBuffer("MIN for Truncated2");
// maxRes->at(0)->printBuffer("MAX for Truncated2"); // maxRes->at(0)->printBuffer("MAX for Truncated2");
delete minRes;
delete maxRes;
} }
TEST_F(RNGTests, Test_Truncated_23) { TEST_F(RNGTests, Test_Truncated_23) {
@ -509,16 +505,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
auto result = op.evaluate({&x0}); auto result = op.evaluate({&x0});
// result->at(0)->printBuffer("MEAN"); // result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE"); // result->at(1)->printBuffer("VARIANCE");
delete result;
sd::ops::reduce_min minOp; sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp; sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {}); auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated3"); // minRes->at(0)->printBuffer("MIN for Truncated3");
// maxRes->at(0)->printBuffer("MAX for Truncated3"); // maxRes->at(0)->printBuffer("MAX for Truncated3");
delete minRes;
delete maxRes;
} }
TEST_F(RNGTests, Test_Truncated_3) { TEST_F(RNGTests, Test_Truncated_3) {
@ -568,15 +562,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
auto op = new sd::ops::LegacyRandomOp(0); auto op = new sd::ops::LegacyRandomOp(0);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
TEST_F(RNGTests, Test_Gaussian_2) { TEST_F(RNGTests, Test_Gaussian_2) {
@ -588,15 +582,15 @@ TEST_F(RNGTests, Test_Gaussian_2) {
auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution); auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
TEST_F(RNGTests, Test_LogNorm_2) { TEST_F(RNGTests, Test_LogNorm_2) {
@ -608,15 +602,15 @@ TEST_F(RNGTests, Test_LogNorm_2) {
auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution); auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
TEST_F(RNGTests, Test_TruncatedNorm_2) { TEST_F(RNGTests, Test_TruncatedNorm_2) {
@ -628,14 +622,14 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) {
auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution); auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
@ -648,15 +642,15 @@ TEST_F(RNGTests, Test_Binomial_2) {
auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx); auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx);
auto result = op->execute(_rngA, {&input}, {0.5f}, {3}); auto result = op->execute(_rngA, {&input}, {0.5f}, {3});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
@ -669,15 +663,15 @@ TEST_F(RNGTests, Test_Bernoulli_2) {
auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution); auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution);
auto result = op->execute(_rngA, {&input}, {0.5f}, {}); auto result = op->execute(_rngA, {&input}, {0.5f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z)); ASSERT_TRUE(x1.equalsTo(z));
delete op; delete op;
delete result;
} }
TEST_F(RNGTests, Test_GaussianDistribution_1) { TEST_F(RNGTests, Test_GaussianDistribution_1) {
@ -687,9 +681,9 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
sd::ops::random_normal op; sd::ops::random_normal op;
auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
@ -698,7 +692,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_BernoulliDistribution_1) { TEST_F(RNGTests, Test_BernoulliDistribution_1) {
@ -708,9 +702,9 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
sd::ops::random_bernoulli op; sd::ops::random_bernoulli op;
auto result = op.evaluate({&x}, {0.5f}, {}); auto result = op.evaluate({&x}, {0.5f}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
@ -718,7 +712,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
} }
@ -729,9 +723,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
sd::ops::random_exponential op; sd::ops::random_exponential op;
auto result = op.evaluate({&x}, {0.25f}, {0}); auto result = op.evaluate({&x}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
@ -740,7 +734,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_ExponentialDistribution_2) { TEST_F(RNGTests, Test_ExponentialDistribution_2) {
@ -753,9 +747,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
sd::ops::random_exponential op; sd::ops::random_exponential op;
auto result = op.evaluate({&x, &y}, {0.25f}, {0}); auto result = op.evaluate({&x, &y}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
@ -764,7 +758,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
ASSERT_FALSE(nexp1->equalsTo(z)); ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z)); ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_PoissonDistribution_1) { TEST_F(RNGTests, Test_PoissonDistribution_1) {
@ -777,14 +771,14 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {
sd::ops::random_poisson op; sd::ops::random_poisson op;
auto result = op.evaluate({&x, &la}, {}, {}); auto result = op.evaluate({&x, &la}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Poisson distribution"); // z->printIndexedBuffer("Poisson distribution");
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_GammaDistribution_1) { TEST_F(RNGTests, Test_GammaDistribution_1) {
@ -797,14 +791,14 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {
sd::ops::random_gamma op; sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al}, {}, {}); auto result = op.evaluate({&x, &al}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution"); // z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_GammaDistribution_2) { TEST_F(RNGTests, Test_GammaDistribution_2) {
@ -818,14 +812,14 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
sd::ops::random_gamma op; sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al, &be}, {}, {}); auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution"); // z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_GammaDistribution_3) { TEST_F(RNGTests, Test_GammaDistribution_3) {
@ -839,14 +833,14 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
sd::ops::random_gamma op; sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al, &be}, {}, {}); auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution"); // z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
delete result;
} }
TEST_F(RNGTests, Test_UniformDistribution_04) { TEST_F(RNGTests, Test_UniformDistribution_04) {
@ -858,13 +852,13 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
sd::ops::randomuniform op; sd::ops::randomuniform op;
auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z)); ASSERT_FALSE(exp0.equalsTo(z));
delete result;
} }
namespace sd { namespace sd {
@ -1021,12 +1015,12 @@ TEST_F(RNGTests, test_multinomial_1) {
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64); NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64);
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
auto outputZ = result->at(0); auto outputZ = result.at(0);
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(expectedZ.isSameShape(outputZ)); ASSERT_TRUE(expectedZ.isSameShape(outputZ));
ASSERT_TRUE(expectedZ.equalsTo(outputZ)); ASSERT_TRUE(expectedZ.equalsTo(outputZ));
delete result;
} }
TEST_F(RNGTests, test_multinomial_2) { TEST_F(RNGTests, test_multinomial_2) {
@ -1117,8 +1111,8 @@ TEST_F(RNGTests, test_multinomial_5) {
} }
auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 }); auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
auto outputR = resultR->at(0); auto outputR = resultR.at(0);
ASSERT_EQ(Status::OK(), resultR->status()); ASSERT_EQ(Status::OK(), resultR.status());
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
mean = outputR->meanNumber(); mean = outputR->meanNumber();
@ -1131,7 +1125,6 @@ TEST_F(RNGTests, test_multinomial_5) {
ASSERT_TRUE(value >= 0 && value < ClassValue); ASSERT_TRUE(value >= 0 && value < ClassValue);
} }
delete resultR;
} }
@ -1150,8 +1143,8 @@ TEST_F(RNGTests, test_multinomial_6) {
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 }); auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
auto outputR = resultR->at(0); auto outputR = resultR.at(0);
ASSERT_EQ(Status::OK(), resultR->status()); ASSERT_EQ(Status::OK(), resultR.status());
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE);
@ -1175,7 +1168,7 @@ TEST_F(RNGTests, test_multinomial_6) {
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3); ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3); ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
delete resultR;
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);

View File

@ -96,14 +96,14 @@ TEST_F(ScalarTests, Test_Concat_1) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -116,15 +116,15 @@ TEST_F(ScalarTests, Test_Concat_2) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
// z->printIndexedBuffer(); // z->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -137,16 +137,16 @@ TEST_F(ScalarTests, Test_Concat_3) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0}); auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
//z->printShapeInfo("z"); //z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ScalarTests, Test_ExpandDims_1) { TEST_F(ScalarTests, Test_ExpandDims_1) {
@ -156,14 +156,14 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
sd::ops::expand_dims op; sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ScalarTests, Test_Squeeze_1) { TEST_F(ScalarTests, Test_Squeeze_1) {
@ -172,14 +172,14 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
sd::ops::squeeze op; sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -189,14 +189,14 @@ TEST_F(ScalarTests, Test_Reshape_1) {
sd::ops::reshape op; sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -206,14 +206,14 @@ TEST_F(ScalarTests, Test_Permute_1) {
sd::ops::permute op; sd::ops::permute op;
auto result = op.evaluate({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ScalarTests, Test_Concat_Scalar_1) { TEST_F(ScalarTests, Test_Concat_Scalar_1) {
@ -225,14 +225,13 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -245,12 +244,11 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }

View File

@ -308,14 +308,13 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
sd::ops::transpose op; sd::ops::transpose op;
auto result = op.evaluate({&x}); auto result = op.evaluate({&x});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(ShapeTests, Tests_Transpose_119_3) { TEST_F(ShapeTests, Tests_Transpose_119_3) {

View File

@ -70,14 +70,14 @@ TEST_F(SingleDimTests, Test_Concat_1) {
sd::ops::concat op; sd::ops::concat op;
auto result = op.evaluate({&x, &y}, {}, {0}); auto result = op.evaluate({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(SingleDimTests, Test_Reduce_1) { TEST_F(SingleDimTests, Test_Reduce_1) {
@ -104,14 +104,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
sd::ops::expand_dims op; sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -122,14 +122,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
sd::ops::expand_dims op; sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {1}); auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -142,14 +142,14 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
sd::ops::squeeze op; sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_EQ(exp.rankOf(), z->rankOf()); ASSERT_EQ(exp.rankOf(), z->rankOf());
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(SingleDimTests, Test_Squeeze_2) { TEST_F(SingleDimTests, Test_Squeeze_2) {
@ -158,14 +158,14 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
sd::ops::squeeze op; sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(SingleDimTests, Test_Reshape_1) { TEST_F(SingleDimTests, Test_Reshape_1) {
@ -174,14 +174,14 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
sd::ops::reshape op; sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3}); auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
TEST_F(SingleDimTests, Test_Reshape_2) { TEST_F(SingleDimTests, Test_Reshape_2) {
@ -190,14 +190,14 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
sd::ops::reshape op; sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3}); auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }
@ -207,12 +207,12 @@ TEST_F(SingleDimTests, Test_Permute_1) {
sd::ops::permute op; sd::ops::permute op;
auto result = op.evaluate({&x}, {}, {0}); auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
delete result;
} }