Improve ResultSet usage in libnd4j (#281)

* libnd4j profiling DeclarableOp and Tests by replacing return ResultSet pointer by instance

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j profiling semantic change in tests cases

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections to make new ResultSet semantic works, fixed one test

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j more tests fixes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - correct copy and move assignment operators of ResultSet class

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Yurii <iuriish@yahoo.com>
Co-authored-by: raver119 <raver119@gmail.com>
master
Oleh 2020-03-10 06:42:50 +02:00 committed by GitHub
parent 57210b936c
commit c3223dbc7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 6726 additions and 7658 deletions

View File

@ -40,6 +40,8 @@ namespace sd {
Nd4jStatus _status = ND4J_STATUS_OK;
bool _removable = true;
void delContent();
public:
explicit ResultSet();

View File

@ -160,9 +160,7 @@ namespace sd {
auto result = op.evaluate(inputs);
auto array = new NDArray(result->at(0)->dup());
delete result;
auto array = new NDArray(result.at(0)->dup());
return array;
}

View File

@ -77,15 +77,16 @@ namespace sd {
}
////////////////////////////////////////////////////////////////////////
// move assignment operator
// move assignment operator
ResultSet& ResultSet::operator=(ResultSet&& other) noexcept {
if (this == &other)
return *this;
this->~ResultSet();
delContent();
_content = std::move(other._content);
_status = other._status;
_removable = other._removable;
other._removable = false;
@ -98,10 +99,10 @@ namespace sd {
if (this == &other)
return *this;
this->~ResultSet();
delContent();
for (const auto v:other._content)
_content.emplace_back(v);
for (const auto v : other._content)
_content.push_back(v);
_status = other._status;
_removable = false;
@ -109,11 +110,15 @@ namespace sd {
return *this;
}
void ResultSet::delContent() {
if (_removable)
for (auto v : _content)
delete v;
}
ResultSet::~ResultSet() {
if (_removable)
for (auto v: _content)
delete v;
delContent();
}
void ResultSet::setNonRemovable() {

View File

@ -60,7 +60,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
// back prop pass
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
@ -78,18 +78,17 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
// add epsilon, feed forward
inArrsFF[i]->p<double>(j, orig + EPSILON);
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
int numOutArrs = outArrsFF->size();
ResultSet outArrsFF = opFF.execute(argsHolderFF);
int numOutArrs = outArrsFF.size();
double scorePlus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
delete outArrsFF;
// subtract epsilon, feed forward
inArrsFF[i]->p<double>(j, orig - EPSILON);
@ -98,12 +97,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
delete outArrsFF;
// restore initial element value
inArrsFF[i]->p<double>(j, orig);
@ -116,7 +114,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
}
// get analytical gradient
const double analyticGrad = outArrsBP->at(i)->e<double>(j);
const double analyticGrad = outArrsBP.at(i)->e<double>(j);
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
throw std::runtime_error("");
@ -138,13 +136,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
continue;
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
delete outArrsBP;
return false;
}
}
}
delete outArrsBP;
return true;
}

View File

@ -45,8 +45,8 @@ namespace sd {
Nd4jStatus execute(Context* block) override;
ResultSet* execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
ResultSet* execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
ResultSet execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
ResultSet execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
};

View File

@ -176,17 +176,17 @@ namespace sd {
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs);
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32);
sd::ResultSet* execute(const sd::OpArgsHolder& holder, bool isInplace = false);
sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false);
// There methods provide various validation options
Nd4jStatus validateNonEmptyInput(Context& block);

View File

@ -41,8 +41,9 @@ namespace sd {
template <typename T>
Nd4jStatus validateAndExecute_(Context &block);
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
Nd4jStatus execute(Context* block) override;
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;

View File

@ -74,10 +74,10 @@ namespace sd {
// at first step we build fwd activation
sd::ops::crelu op;
auto tmpResult = op.evaluate({input});
if (tmpResult->status() != ND4J_STATUS_OK)
return tmpResult->status();
if (tmpResult.status() != ND4J_STATUS_OK)
return tmpResult.status();
auto actv = tmpResult->at(0);
auto actv = tmpResult.at(0);
// now we do RELU backward pass
//actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr);
@ -85,17 +85,15 @@ namespace sd {
// now we split updated array into 2 chunks along last dimension
sd::ops::concat_bp opc;
auto dec = opc.evaluate({input, input, actv}, {-1});
if (dec->status() != ND4J_STATUS_OK)
return dec->status();
if (dec.status() != ND4J_STATUS_OK)
return dec.status();
// and now we subtract two parts of epsilons and pass result out
auto pos = dec->at(0);
auto neg = dec->at(1);
auto pos = dec.at(0);
auto neg = dec.at(1);
pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon);
delete tmpResult;
delete dec;
return ND4J_STATUS_OK;
}

View File

@ -102,10 +102,12 @@ namespace sd {
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
// if (output->isEmpty())
Nd4jLong width = condition->rankOf();
sd::ops::Where op;
std::unique_ptr<ResultSet> res(op.evaluate({condition}));
REQUIRE_OK(res->status());
NDArray* whereTrue = res->at(0);
auto res(op.evaluate({condition}));
REQUIRE_OK(res.status());
NDArray* whereTrue = res.at(0);
if (whereTrue->isEmpty())
return ND4J_STATUS_OK;
for (Nd4jLong outNext = 0; outNext < width; ++outNext) {

View File

@ -65,11 +65,12 @@ namespace sd {
auto gradX = OUTPUT_VARIABLE(0);
auto gradY = OUTPUT_VARIABLE(1);
gradX->assign(epsNext);
sd::ops::floormod op;
std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));
auto tmpResult(op.evaluate({x, y}));
if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
else // epsNext is greater than gradY
{
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -77,7 +78,7 @@ namespace sd {
for (Nd4jLong d = 0; d < gap; d++) {
dims[d * 2 + 1] = 1;
}
auto tempIn((*tmpResult->at(0))(dims));
auto tempIn((*tmpResult.at(0))(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
}
return Status::OK();

View File

@ -113,23 +113,21 @@ namespace ops {
originalIndices.linspace(0);
ops::dynamic_partition op;
auto res = op.evaluate({&originalIndices, indices}, {numPartition});
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2);
for (size_t i = 0; i < res->size(); i++) {
partitions[i] = res->at(i);
for (size_t i = 0; i < res.size(); i++) {
partitions[i] = res.at(i);
partitions[i + numPartition] = gradOutList[i];
}
auto result = stichOp.evaluate(partitions, {numPartition});
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result.at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices);
outputList[0]->assign(result->at(0));
outputList[0]->assign(result.at(0));
// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList);
delete res;
delete result;
return ND4J_STATUS_OK;
}

View File

@ -66,10 +66,10 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
sd::ops::gather op;
std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result->at(0));
auto result(op.evaluate({input, indeces}, {0}));
REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
output->assign(result.at(0));
}
return Status::OK();
}

View File

@ -95,8 +95,8 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// forward steps
sd::ops::dynamic_rnn dynamicRnn;
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW->at(1));
hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
hFWFinal->assign(resultsFW.at(1));
auto seqLen = maxTimeStep;
if(seqLen == nullptr) {
@ -108,22 +108,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
// reverse x
sd::ops::reverse_sequence reverse;
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0);
REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn.at(0);
// backward steps
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
hBWFinal->assign(resultsBW->at(1));
auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
hBWFinal->assign(resultsBW.at(1));
// reverse hBWtemp
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0));
delete resultsOut;
delete resultsBW;
delete resultsIn;
delete resultsFW;
hBW->assign(resultsOut.at(0));
if(seqLen != maxTimeStep)
delete seqLen;
@ -228,12 +223,6 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
}
}
}

View File

@ -52,14 +52,13 @@ namespace helpers {
sd::ops::unique opUnique;
auto uResult = opUnique.evaluate({&arrayFull});
if (Status::OK() != uResult->status())
if (Status::OK() != uResult.status())
throw std::runtime_error("multiUnique: cannot execute unique op properly.");
auto uniqueVals = uResult->at(0);
auto uniqueVals = uResult.at(0);
bool res = uniqueVals->lengthOf() == arrayFull.lengthOf();
delete uResult;
return res;
}

View File

@ -64,7 +64,7 @@ namespace sd {
block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList);
}
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
std::vector<NDArray*> ins(inputs);
std::vector<double> tas(tArgs);
std::vector<int> ias(iArgs);
@ -94,7 +94,7 @@ namespace sd {
return status;
}
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
VariableSpace varSpace;
int nodeId = 119;
@ -132,8 +132,8 @@ namespace sd {
Nd4jStatus result = this->validateAndExecute(block);
auto res = new ResultSet();
res->setStatus(result);
ResultSet res;
res.setStatus(result);
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e);
@ -143,10 +143,10 @@ namespace sd {
auto arr = var->getNDArray();
if (arr->isAttached()) {
auto d = arr->detach();
res->push_back(d);
res.push_back(d);
} else {
var->markRemovable(false);
res->push_back(arr);
res.push_back(arr);
}
}
} else

View File

@ -962,12 +962,12 @@ namespace sd {
return execute(&ctx);
}
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
std::vector<Nd4jLong> realArgs;
for (auto v:iArgs)
realArgs.emplace_back(v);
@ -976,12 +976,12 @@ namespace sd {
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>());
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
std::vector<double> realArgs;
for (auto v:tArgs)
realArgs.emplace_back(v);
@ -990,21 +990,21 @@ namespace sd {
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>());
}
template <>
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
}
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
VariableSpace variableSpace;
//ResultSet arrayList;
FlowPath fp;
@ -1041,11 +1041,11 @@ namespace sd {
block.getDArguments()->push_back(dArgs.at(e));
Nd4jStatus status = this->execute(&block);
auto arrayList = new ResultSet();
ResultSet arrayList;
if (isInplace)
arrayList->setNonRemovable();
arrayList.setNonRemovable();
arrayList->setStatus(status);
arrayList.setStatus(status);
if (status != ND4J_STATUS_OK)
return arrayList;
@ -1058,23 +1058,23 @@ namespace sd {
if (!arr->isAttached()) {
var->markRemovable(false);
arr->setContext(sd::LaunchContext::defaultContext());
arrayList->push_back(arr);
arrayList.push_back(arr);
} else {
arrayList->push_back(arr->detach());
arrayList.push_back(arr->detach());
}
} else
break;
}
} else {
for (auto v:inputs) {
arrayList->push_back(v);
arrayList.push_back(v);
}
}
return arrayList;
}
sd::ResultSet* sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
// FIXME: add DArgs to OpArgsHolder
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace);
}

View File

@ -357,20 +357,20 @@ namespace sd {
return DeclarableOp::execute(block);
}
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
std::vector<NDArray*> ins(inputs);
std::vector<double> tas(tArgs);
std::vector<int> ias(iArgs);
return this->execute(rng, ins, tas, ias, isInplace);
}
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
VariableSpace variableSpace;
auto arrayList = new ResultSet();
ResultSet arrayList;
//ResultSet arrayList;
if (isInplace)
arrayList->setNonRemovable();
arrayList.setNonRemovable();
int cnt = -1;
std::vector<int> in;
@ -398,7 +398,7 @@ namespace sd {
block.getIArguments()->emplace_back(iArgs.at(e));
Nd4jStatus status = this->execute(&block);
arrayList->setStatus(status);
arrayList.setStatus(status);
if (status != ND4J_STATUS_OK)
return arrayList;
@ -410,9 +410,9 @@ namespace sd {
auto arr = var->getNDArray();
if (!arr->isAttached()) {
var->markRemovable(false);
arrayList->push_back(arr);
arrayList.push_back(arr);
} else {
arrayList->push_back(arr->detach());
arrayList.push_back(arr->detach());
}
} else
break;

View File

@ -44,9 +44,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
/*
@ -72,9 +70,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
@ -86,9 +82,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
/*
@ -118,9 +112,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
sd::ops::dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
/*
@ -154,9 +146,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
sd::ops::multi_head_dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
/*
@ -198,9 +188,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
sd::ops::multi_head_dot_product_attention op;
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
/*

View File

@ -39,13 +39,11 @@ TEST_F(BackpropTests, Test_Add_1) {
sd::ops::add_bp op;
auto result = op.evaluate({&x, &y, &e});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto eps = result->at(0);
auto grad = result->at(1);
auto eps = result.at(0);
auto grad = result.at(1);
ASSERT_TRUE(x.isSameShape(eps));
ASSERT_TRUE(y.isSameShape(grad));
delete result;
}

View File

@ -137,13 +137,12 @@ TEST_F(BooleanOpsTests, test_where_1) {
sd::ops::choose op;
auto result = op.evaluate({&x, &y}, {3});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
//z->printIndexedBuffer("z");
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -48,9 +48,9 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
sd::ops::add op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
//exp.printIndexedBuffer("E A");
//z->printIndexedBuffer("Z");
@ -58,7 +58,6 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -75,14 +74,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
sd::ops::multiply op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -100,14 +97,13 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
sd::ops::squaredsubtract op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -119,14 +115,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
sd::ops::subtract op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -138,14 +132,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
sd::ops::add op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -156,14 +148,12 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
sd::ops::maximum op;
auto result = op.evaluate({&x, &row});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -174,15 +164,14 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
sd::ops::minimum op;
auto result = op.evaluate({&x, &col});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -283,14 +272,13 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
sd::ops::add op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -333,11 +321,9 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
sd::ops::subtract op;
auto result = op.evaluate({&x, &y});
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z));
delete result;
}
TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
@ -511,13 +497,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
sd::ops::multiply op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z));
delete result;
}
TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
@ -527,13 +512,11 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
sd::ops::multiply op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
@ -606,14 +589,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
sd::ops::maximum op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
@ -625,14 +606,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
sd::ops::maximum op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
@ -644,14 +624,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
sd::ops::realdiv op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
@ -663,14 +642,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
sd::ops::realdiv op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
@ -682,14 +660,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
sd::ops::realdiv op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
@ -718,15 +694,13 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
sd::ops::greater op;
auto result = op.evaluate({&x, &y});
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z");
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -48,9 +48,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
sd::ops::conv2d op;
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());
delete result;
ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status());
}
TEST_F(DataTypesValidationTests, Basic_Test_2) {
@ -63,13 +61,12 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
sd::ops::conv2d op;
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -47,13 +47,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
sd::ops::scatter_upd op;
auto result = op.evaluate({ &x, &y, &w });
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests16, scatter_upd_2) {
@ -67,13 +65,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
sd::ops::scatter_upd op;
auto result = op.evaluate({ &x, &indices, &updates });
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
@ -136,13 +132,11 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
sd::ops::bits_hamming_distance op;
auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
@ -167,10 +161,8 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
sd::ops::cast op;
auto result = op.evaluate({&x}, {10});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0));
delete result;
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(e, *result.at(0));
}
TEST_F(DeclarableOpsTests16, test_range_1) {
@ -717,8 +709,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
@ -767,7 +757,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
@ -798,7 +787,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
@ -826,7 +814,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
@ -850,7 +837,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
@ -891,8 +877,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
@ -937,8 +921,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
@ -983,7 +965,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
@ -1010,7 +991,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
@ -1037,8 +1017,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
@ -1061,7 +1039,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
@ -1096,5 +1073,4 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}

View File

@ -49,9 +49,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
sd::ops::compat_sparse_to_dense op;
auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status());
delete result;
ASSERT_EQ(Status::OK(), result.status());
}
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
@ -64,9 +62,8 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
sd::ops::compat_sparse_to_dense op;
auto result = op.evaluate({&ranges, &shape, &values, &def});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
delete result;
}
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
@ -78,11 +75,11 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
sd::ops::compat_string_split op;
auto result = op.evaluate({&x, &delimiter});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(2, result->size());
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(2, result.size());
auto z0 = result->at(0);
auto z1 = result->at(1);
auto z0 = result.at(0);
auto z1 = result.at(1);
ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp1.isSameShape(z1));
@ -90,5 +87,4 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
ASSERT_EQ(exp0, *z0);
ASSERT_EQ(exp1, *z1);
delete result;
}

View File

@ -62,10 +62,8 @@ TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
sd::ops::conv1d_bp op;
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
delete result;
}
TEST_F(DeclarableOpsTests19, test_squeeze_1) {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -51,14 +51,12 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
sd::ops::choose op;
//greater than test
auto result = op.evaluate({&x}, {0.0},{3});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(1);
auto z = result.at(1);
ASSERT_EQ(148,z->e<double>(0));
//ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
/*

View File

@ -67,9 +67,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
sd::ops::concat op;
auto result = op.evaluate({empty, vector}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z shape");
// z->printIndexedBuffer("z buffr");
@ -78,7 +78,6 @@ TEST_F(EmptyTests, Test_Concat_1) {
delete empty;
delete vector;
delete result;
}
@ -92,9 +91,9 @@ TEST_F(EmptyTests, Test_Concat_2) {
sd::ops::concat op;
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z shape");
// z->printIndexedBuffer("z buffr");
@ -104,7 +103,6 @@ TEST_F(EmptyTests, Test_Concat_2) {
delete empty;
delete scalar1;
delete scalar2;
delete result;
}
TEST_F(EmptyTests, Test_Concat_3) {
@ -117,13 +115,12 @@ TEST_F(EmptyTests, Test_Concat_3) {
sd::ops::concat op;
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(EmptyTests, Test_Concat_4) {
@ -136,13 +133,11 @@ TEST_F(EmptyTests, Test_Concat_4) {
sd::ops::concat op;
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(EmptyTests, Test_Reshape_1) {
@ -153,12 +148,11 @@ TEST_F(EmptyTests, Test_Reshape_1) {
sd::ops::reshape op;
auto result = op.evaluate({&vector, empty}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(exp, *result->at(0));
ASSERT_EQ(exp, *result.at(0));
delete empty;
delete result;
}
TEST_F(EmptyTests, Test_Reshape_3) {
@ -168,14 +162,13 @@ TEST_F(EmptyTests, Test_Reshape_3) {
sd::ops::reshape op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(EmptyTests, Test_dup_1) {
@ -197,12 +190,11 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
sd::ops::scatter_upd op;
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(x, *z);
delete result;
}
TEST_F(EmptyTests, test_empty_scatter_2) {
@ -288,17 +280,15 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
sd::ops::reshape op;
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0->status());
auto z0 = result0->at(0);
ASSERT_EQ(Status::OK(), result0.status());
auto z0 = result0.at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1->status());
auto z1 = result1->at(0);
ASSERT_EQ(Status::OK(), result1.status());
auto z1 = result1.at(0);
ASSERT_EQ(e1, *z1);
delete result0;
delete result1;
}
@ -309,12 +299,11 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
sd::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(EmptyTests, test_empty_matmul_2) {
@ -324,10 +313,8 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
sd::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -1889,20 +1889,19 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) {
OpArgsHolder holderFF({&input}, {}, {2, 3});
sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly
auto results = opFF.execute(holderFF);
auto tiled = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
auto tiled = results.at(0);
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(tiled));
ASSERT_TRUE(exp.equalsTo(tiled));
delete results;
OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true);
sd::ops::tile_bp opBP;
results = opBP.execute(holderBP);
auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
auto gradI = results.at(0);
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
}

View File

@ -47,13 +47,13 @@ TEST_F(IndexingTests, StridedSlice_1) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -66,14 +66,14 @@ TEST_F(IndexingTests, StridedSlice_2) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -86,14 +86,14 @@ TEST_F(IndexingTests, StridedSlice_3) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -109,15 +109,15 @@ TEST_F(IndexingTests, SimpleSlice_1) {
sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -135,15 +135,15 @@ TEST_F(IndexingTests, SimpleSlice_2) {
sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, SimpleSlice_3) {
@ -160,15 +160,15 @@ TEST_F(IndexingTests, SimpleSlice_3) {
sd::ops::slice op;
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, SimpleSlice_4) {
@ -180,14 +180,14 @@ TEST_F(IndexingTests, SimpleSlice_4) {
sd::ops::slice op;
auto result = op.evaluate({&input, &start, &stop});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -204,16 +204,16 @@ TEST_F(IndexingTests, MaskedSlice_0) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -230,14 +230,14 @@ TEST_F(IndexingTests, MaskedSlice_00) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -254,16 +254,16 @@ TEST_F(IndexingTests, MaskedSlice_1) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, MaskedSlice_2) {
@ -275,14 +275,14 @@ TEST_F(IndexingTests, MaskedSlice_2) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -295,14 +295,14 @@ TEST_F(IndexingTests, MaskedSlice_3) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -315,15 +315,15 @@ TEST_F(IndexingTests, MaskedSlice_4) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, Live_Slice_1) {
@ -338,16 +338,16 @@ TEST_F(IndexingTests, Live_Slice_1) {
sd::ops::strided_slice op;
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -361,14 +361,14 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, Test_StridedSlice_2) {
@ -381,16 +381,16 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -404,14 +404,14 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
sd::ops::strided_slice op;
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -426,16 +426,16 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
//z->printIndexedBuffer("Z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(IndexingTests, Test_Subarray_Strided_1) {
@ -458,13 +458,13 @@ TEST_F(IndexingTests, MaskedSlice_5) {
sd::ops::strided_slice<float> op;
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
*/

View File

@ -64,13 +64,13 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(LegacyOpsTests, Reciprocal_1) {
@ -121,12 +121,12 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
auto result = op.evaluate({&x, &y}, {}, {});
auto z = result->at(0);
auto z = result.at(0);
//z->printBuffer("Z");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(LegacyOpsTests, Scalar_Test_1) {
@ -154,10 +154,10 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
sd::ops::LegacyScalarOp op(scalar::Add, y);
auto result = op.evaluate({&x}, {}, {});
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -169,14 +169,14 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
// z->printBuffer("ReduceTest1");
ASSERT_TRUE(z->isScalar());
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
delete result;
}
@ -188,16 +188,16 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum, {1});
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -209,15 +209,15 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
sd::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum,{1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -229,16 +229,16 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
sd::ops::LegacyReduceSameOp op(reduce::Sum);
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
// indices.printShapeInfo("Indices shape");
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
// z->printIndexedBuffer("Output reduce 4");
// exp.printIndexedBuffer("Expected reduce 4");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(LegacyOpsTests, ReduceTests_5) {
@ -249,14 +249,14 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
auto result = op.evaluate({&x});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
// z->printBuffer("ReduceTest1");
ASSERT_TRUE(z->isScalar());
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
delete result;
}
@ -268,16 +268,16 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
auto result = op.evaluate({&x, &axis}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean, {1});
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -289,15 +289,15 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.evaluate({&x, &indices}, {}, {});
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean,{1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -309,17 +309,17 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
auto z = result->at(0);
auto z = result.at(0);
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
// z->printIndexedBuffer("Reduce8 output");
// z->printShapeInfo("Reduce8 shape");
// exp.printShapeInfo("Reduce8 expected shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -331,14 +331,14 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(z->isScalar());
ASSERT_EQ(24, z->e<int>(0));
delete result;
}
@ -351,9 +351,9 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
auto result = op.evaluate({&x, &indices}, {}, {});
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Hello indexreduce2");
ASSERT_TRUE(exp.equalsTo(z));
//ASSERT_EQ(4, z->e<int>(0));
@ -362,7 +362,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
//ASSERT_EQ(4, z->e<int>(3));
//ASSERT_EQ(4, z->e<int>(4));
delete result;
}
TEST_F(LegacyOpsTests, BroadcastingTests_1) {

View File

@ -39,7 +39,7 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
auto result = op.execute(&list, {&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(1, list.elements());
@ -47,8 +47,8 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
ASSERT_EQ(2, list.elements());
delete result;
delete result2;
}
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
@ -66,15 +66,15 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
auto result = op.execute(&list, {}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printShapeInfo();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
@ -93,10 +93,10 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
auto result = op.execute(&list, {&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(list.elements(), 10);
// auto z = result->at(0);
// auto z = result.at(0);
// z->printShapeInfo("The first of");
// ASSERT_TRUE(exp.isSameShape(z));
// ASSERT_TRUE(exp.equalsTo(z));
@ -107,7 +107,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
delete row;
}
delete result;
}
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
@ -126,20 +126,20 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
//
// auto result = op.execute(nullptr, {&x}, {}, {0});
//
// ASSERT_EQ(ND4J_STATUS_OK, result->status());
// ASSERT_EQ(ND4J_STATUS_OK, result.status());
// ASSERT_EQ(result->size(), 10);
//
// // auto z = result->at(0);
// // auto z = result.at(0);
//// z->printShapeInfo("The first of");
//// ASSERT_TRUE(exp.isSameShape(z));
//// ASSERT_TRUE(exp.equalsTo(z));
// for (int e = 0; e < 10; e++) {
// auto row = result->at(e);
// auto row = result.at(e);
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
// //list.write(e, row);
// }
//
// delete result;
//
// delete tads;
//}
@ -160,14 +160,14 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) {
auto result = op.execute(&list, {}, {}, {4});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
@ -192,14 +192,14 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) {
sd::ops::pick_list op;
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Size_1) {
@ -217,14 +217,14 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) {
auto result = op.execute(&list, {}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Create_1) {
@ -235,12 +235,12 @@ TEST_F(ListOperationsTests, BasicTest_Create_1) {
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
// we return flow as well
ASSERT_EQ(1, result->size());
ASSERT_EQ(1, result.size());
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Split_1) {
@ -283,7 +283,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
sd::ops::split_list op;
auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(3, list.height());
@ -296,7 +296,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
@ -319,7 +319,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
sd::ops::scatter_list op;
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
for (int e = 0; e < 10; e++) {
auto row = tads.at(9 - e);
@ -329,7 +329,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
ASSERT_TRUE(chunk->equalsTo(row));
}
delete result;
}
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
@ -385,10 +385,10 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
sd::ops::gather_list op;
auto result = op.execute(&list, {&indices}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(1, result->size());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_EQ(1, result.size());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
@ -397,7 +397,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {

View File

@ -134,13 +134,11 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
sd::ops::add op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(e, *z);
delete result;
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -66,7 +66,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true);
@ -74,7 +74,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
ASSERT_EQ(exp0, row0);
ASSERT_EQ(exp1, row1);
delete result;
}
TEST_F(NlpTests, basic_sg_hs_test_2) {
@ -107,7 +107,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true);
@ -117,7 +117,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
ASSERT_EQ(exp1, row1);
ASSERT_EQ(exp2, row2);
delete result;
}
TEST_F(NlpTests, basic_sg_hs_test_3) {
@ -159,7 +159,7 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
sd::ops::skipgram op;
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result0->status());
ASSERT_EQ(Status::OK(), result0.status());
auto row00 = syn00({0,1, 0,0}, true);
auto row01 = syn01({0,1, 0,0}, true);
@ -168,9 +168,6 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
ASSERT_EQ(row2, row1);
ASSERT_EQ(row00, row01);
delete result0;
delete result1;
}
TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
@ -192,9 +189,9 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
delete result;
}
TEST_F(NlpTests, basic_sg_ns_test_1) {
@ -227,14 +224,14 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({1,2, 0,0}, true);
ASSERT_EQ(exp0, row0);
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
delete result;
}
TEST_F(NlpTests, basic_cb_hs_test_1) {
@ -269,7 +266,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row_s0_0 = syn0({0,1, 0,0}, true);
auto row_s0_1 = syn0({1,2, 0,0}, true);
@ -287,7 +284,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
ASSERT_EQ(exp1, row_s1_5);
ASSERT_EQ(exp2, row_s1_6);
delete result;
}
TEST_F(NlpTests, basic_cb_ns_test_1) {
@ -323,7 +320,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row_s0_0 = syn0({0,1, 0,0}, true);
auto row_s0_1 = syn0({1,2, 0,0}, true);
@ -339,7 +336,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
ASSERT_EQ(exp0, row_s0_2);
ASSERT_EQ(exp2, row_s1_6);
delete result;
}
TEST_F(NlpTests, test_sg_hs_batch_1) {
@ -372,7 +369,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto row0 = syn0({0,1, 0,0}, true);
auto row1 = syn1({1,2, 0,0}, true);
@ -382,7 +379,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
delete result;
}
TEST_F(NlpTests, test_sg_ns_batch_1) {
@ -416,9 +413,9 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
sd::ops::skipgram op;
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
delete result;
}
TEST_F(NlpTests, test_cbow_hs_batch_1) {
@ -449,7 +446,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
sd::ops::cbow op;
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
@ -474,5 +471,4 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
ASSERT_EQ(exp1, row_s1_5);
ASSERT_EQ(exp2, row_s1_6);
delete result;
}

File diff suppressed because it is too large Load Diff

View File

@ -260,9 +260,9 @@ TEST_F(RNGTests, Test_Gaussian_21) {
auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK());
auto mean = result->at(0);
auto variance = result->at(1);
ASSERT_TRUE(result.status() == Status::OK());
auto mean = result.at(0);
auto variance = result.at(1);
// mean->printIndexedBuffer("Mean");
// variance->printIndexedBuffer("Variance");
@ -270,7 +270,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f);
ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f);
delete result;
}
#ifdef DEBUG_BUILD
@ -292,15 +292,15 @@ TEST_F(RNGTests, Test_Gaussian_22) {
auto result = op.evaluate({&x0}, {}, {});
//x0.printIndexedBuffer("X0 Normal");
//x1.printIndexedBuffer("X1 Normal");
ASSERT_TRUE(result->status() == Status::OK());
auto mean0 = result->at(0);
auto variance0 = result->at(1);
ASSERT_TRUE(result.status() == Status::OK());
auto mean0 = result.at(0);
auto variance0 = result.at(1);
//mean0->printIndexedBuffer("Mean");
//variance0->printIndexedBuffer("Variance");
ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f);
ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f);
delete result;
}
TEST_F(RNGTests, Test_Gaussian_3) {
@ -413,18 +413,17 @@ TEST_F(RNGTests, Test_Truncated_21) {
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
sd::ops::moments op;
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE");
delete result;
// result.at(0)->printBuffer("MEAN");
// result.at(1)->printBuffer("VARIANCE");
sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated");
// maxRes->at(0)->printBuffer("MAX for Truncated");
delete minRes;
delete maxRes;
}
TEST_F(RNGTests, Test_Truncated_22) {
@ -460,18 +459,15 @@ TEST_F(RNGTests, Test_Truncated_22) {
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
sd::ops::moments op;
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
// result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE");
delete result;
sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated2");
// maxRes->at(0)->printBuffer("MAX for Truncated2");
delete minRes;
delete maxRes;
}
TEST_F(RNGTests, Test_Truncated_23) {
@ -509,16 +505,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
auto result = op.evaluate({&x0});
// result->at(0)->printBuffer("MEAN");
// result->at(1)->printBuffer("VARIANCE");
delete result;
sd::ops::reduce_min minOp;
sd::ops::reduce_max maxOp;
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
// minRes->at(0)->printBuffer("MIN for Truncated3");
// maxRes->at(0)->printBuffer("MAX for Truncated3");
delete minRes;
delete maxRes;
}
TEST_F(RNGTests, Test_Truncated_3) {
@ -568,15 +562,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
auto op = new sd::ops::LegacyRandomOp(0);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
TEST_F(RNGTests, Test_Gaussian_2) {
@ -588,15 +582,15 @@ TEST_F(RNGTests, Test_Gaussian_2) {
auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
TEST_F(RNGTests, Test_LogNorm_2) {
@ -608,15 +602,15 @@ TEST_F(RNGTests, Test_LogNorm_2) {
auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
TEST_F(RNGTests, Test_TruncatedNorm_2) {
@ -628,14 +622,14 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) {
auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution);
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
@ -648,15 +642,15 @@ TEST_F(RNGTests, Test_Binomial_2) {
auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx);
auto result = op->execute(_rngA, {&input}, {0.5f}, {3});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
@ -669,15 +663,15 @@ TEST_F(RNGTests, Test_Bernoulli_2) {
auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution);
auto result = op->execute(_rngA, {&input}, {0.5f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(x1.isSameShape(z));
ASSERT_TRUE(x1.equalsTo(z));
delete op;
delete result;
}
TEST_F(RNGTests, Test_GaussianDistribution_1) {
@ -687,9 +681,9 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
sd::ops::random_normal op;
auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
@ -698,7 +692,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_BernoulliDistribution_1) {
@ -708,9 +702,9 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
sd::ops::random_bernoulli op;
auto result = op.evaluate({&x}, {0.5f}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_FALSE(exp0.equalsTo(z));
@ -718,7 +712,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
}
@ -729,9 +723,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
sd::ops::random_exponential op;
auto result = op.evaluate({&x}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
@ -740,7 +734,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
@ -753,9 +747,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
sd::ops::random_exponential op;
auto result = op.evaluate({&x, &y}, {0.25f}, {0});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
@ -764,7 +758,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
ASSERT_FALSE(nexp1->equalsTo(z));
ASSERT_FALSE(nexp2->equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_PoissonDistribution_1) {
@ -777,14 +771,14 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {
sd::ops::random_poisson op;
auto result = op.evaluate({&x, &la}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Poisson distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_1) {
@ -797,14 +791,14 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {
sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_2) {
@ -818,14 +812,14 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_3) {
@ -839,14 +833,14 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
sd::ops::random_gamma op;
auto result = op.evaluate({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_UniformDistribution_04) {
@ -858,13 +852,13 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
sd::ops::randomuniform op;
auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
namespace sd {
@ -1021,12 +1015,12 @@ TEST_F(RNGTests, test_multinomial_1) {
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64);
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
auto outputZ = result->at(0);
auto outputZ = result.at(0);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
delete result;
}
TEST_F(RNGTests, test_multinomial_2) {
@ -1117,8 +1111,8 @@ TEST_F(RNGTests, test_multinomial_5) {
}
auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status());
auto outputR = resultR.at(0);
ASSERT_EQ(Status::OK(), resultR.status());
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
mean = outputR->meanNumber();
@ -1131,7 +1125,6 @@ TEST_F(RNGTests, test_multinomial_5) {
ASSERT_TRUE(value >= 0 && value < ClassValue);
}
delete resultR;
}
@ -1150,8 +1143,8 @@ TEST_F(RNGTests, test_multinomial_6) {
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status());
auto outputR = resultR.at(0);
ASSERT_EQ(Status::OK(), resultR.status());
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE);
@ -1175,7 +1168,7 @@ TEST_F(RNGTests, test_multinomial_6) {
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
delete resultR;
RandomGenerator rng(1234, 1234);
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);

View File

@ -96,14 +96,14 @@ TEST_F(ScalarTests, Test_Concat_1) {
sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -116,15 +116,15 @@ TEST_F(ScalarTests, Test_Concat_2) {
sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
// z->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -137,16 +137,16 @@ TEST_F(ScalarTests, Test_Concat_3) {
sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
//z->printShapeInfo("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ScalarTests, Test_ExpandDims_1) {
@ -156,14 +156,14 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ScalarTests, Test_Squeeze_1) {
@ -172,14 +172,14 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -189,14 +189,14 @@ TEST_F(ScalarTests, Test_Reshape_1) {
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -206,14 +206,14 @@ TEST_F(ScalarTests, Test_Permute_1) {
sd::ops::permute op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
@ -225,14 +225,13 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -245,12 +244,11 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
sd::ops::concat op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}

View File

@ -308,14 +308,13 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
sd::ops::transpose op;
auto result = op.evaluate({&x});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(Status::OK(), result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ShapeTests, Tests_Transpose_119_3) {

View File

@ -70,14 +70,14 @@ TEST_F(SingleDimTests, Test_Concat_1) {
sd::ops::concat op;
auto result = op.evaluate({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(SingleDimTests, Test_Reduce_1) {
@ -104,14 +104,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -122,14 +122,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
sd::ops::expand_dims op;
auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -142,14 +142,14 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_EQ(exp.rankOf(), z->rankOf());
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(SingleDimTests, Test_Squeeze_2) {
@ -158,14 +158,14 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
sd::ops::squeeze op;
auto result = op.evaluate({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(SingleDimTests, Test_Reshape_1) {
@ -174,14 +174,14 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(SingleDimTests, Test_Reshape_2) {
@ -190,14 +190,14 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
sd::ops::reshape op;
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
@ -207,12 +207,12 @@ TEST_F(SingleDimTests, Test_Permute_1) {
sd::ops::permute op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result->at(0);
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}