Improve ResultSet usage in libnd4j (#281)
* libnd4j profiling DeclarableOp and Tests by replacing return ResultSet pointer by instance Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j profiling semantic change in tests cases Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections to make new ResultSet semantic works, fixed one test Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j more tests fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - correct copy and move assignment operators of ResultSet class Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
57210b936c
commit
c3223dbc7a
|
@ -15,8 +15,8 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// This class is suited for execution results representation.
|
||||
//
|
||||
// This class is suited for execution results representation.
|
||||
//
|
||||
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
|
@ -33,13 +33,15 @@
|
|||
namespace sd {
|
||||
|
||||
class NDArray; // forward declaration of template class NDArray
|
||||
|
||||
|
||||
class ND4J_EXPORT ResultSet {
|
||||
private:
|
||||
std::vector<sd::NDArray *> _content;
|
||||
Nd4jStatus _status = ND4J_STATUS_OK;
|
||||
bool _removable = true;
|
||||
|
||||
void delContent();
|
||||
|
||||
public:
|
||||
explicit ResultSet();
|
||||
|
||||
|
@ -56,7 +58,7 @@ namespace sd {
|
|||
|
||||
// move assignment operator
|
||||
ResultSet& operator=(ResultSet&& other) noexcept;
|
||||
|
||||
|
||||
~ResultSet();
|
||||
|
||||
int size();
|
||||
|
|
|
@ -160,9 +160,7 @@ namespace sd {
|
|||
|
||||
auto result = op.evaluate(inputs);
|
||||
|
||||
auto array = new NDArray(result->at(0)->dup());
|
||||
|
||||
delete result;
|
||||
auto array = new NDArray(result.at(0)->dup());
|
||||
|
||||
return array;
|
||||
}
|
||||
|
|
|
@ -77,15 +77,16 @@ namespace sd {
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// move assignment operator
|
||||
// move assignment operator
|
||||
ResultSet& ResultSet::operator=(ResultSet&& other) noexcept {
|
||||
|
||||
if (this == &other)
|
||||
if (this == &other)
|
||||
return *this;
|
||||
|
||||
this->~ResultSet();
|
||||
delContent();
|
||||
|
||||
_content = std::move(other._content);
|
||||
|
||||
_status = other._status;
|
||||
_removable = other._removable;
|
||||
other._removable = false;
|
||||
|
@ -98,10 +99,10 @@ namespace sd {
|
|||
if (this == &other)
|
||||
return *this;
|
||||
|
||||
this->~ResultSet();
|
||||
delContent();
|
||||
|
||||
for (const auto v:other._content)
|
||||
_content.emplace_back(v);
|
||||
for (const auto v : other._content)
|
||||
_content.push_back(v);
|
||||
|
||||
_status = other._status;
|
||||
_removable = false;
|
||||
|
@ -109,11 +110,15 @@ namespace sd {
|
|||
return *this;
|
||||
}
|
||||
|
||||
void ResultSet::delContent() {
|
||||
if (_removable)
|
||||
for (auto v : _content)
|
||||
delete v;
|
||||
}
|
||||
|
||||
ResultSet::~ResultSet() {
|
||||
if (_removable)
|
||||
for (auto v: _content)
|
||||
delete v;
|
||||
|
||||
delContent();
|
||||
}
|
||||
|
||||
void ResultSet::setNonRemovable() {
|
||||
|
|
|
@ -60,7 +60,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
||||
|
||||
// back prop pass
|
||||
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
||||
ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
||||
|
||||
NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
|
||||
|
||||
|
@ -78,18 +78,17 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
|
||||
// add epsilon, feed forward
|
||||
inArrsFF[i]->p<double>(j, orig + EPSILON);
|
||||
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
|
||||
int numOutArrs = outArrsFF->size();
|
||||
ResultSet outArrsFF = opFF.execute(argsHolderFF);
|
||||
int numOutArrs = outArrsFF.size();
|
||||
double scorePlus = 0.;
|
||||
|
||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||
if(loss == SUM)
|
||||
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||
else
|
||||
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||
scorePlus += tmpScalar.e<double>(0);
|
||||
}
|
||||
delete outArrsFF;
|
||||
|
||||
// subtract epsilon, feed forward
|
||||
inArrsFF[i]->p<double>(j, orig - EPSILON);
|
||||
|
@ -98,12 +97,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
|
||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||
if(loss == SUM)
|
||||
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||
else
|
||||
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||
scoreMinus += tmpScalar.e<double>(0);
|
||||
}
|
||||
delete outArrsFF;
|
||||
|
||||
// restore initial element value
|
||||
inArrsFF[i]->p<double>(j, orig);
|
||||
|
@ -116,7 +114,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
}
|
||||
|
||||
// get analytical gradient
|
||||
const double analyticGrad = outArrsBP->at(i)->e<double>(j);
|
||||
const double analyticGrad = outArrsBP.at(i)->e<double>(j);
|
||||
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
|
||||
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
|
||||
throw std::runtime_error("");
|
||||
|
@ -138,13 +136,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
continue;
|
||||
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
|
||||
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
|
||||
delete outArrsBP;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete outArrsBP;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ namespace sd {
|
|||
Nd4jStatus execute(Context* block) override;
|
||||
|
||||
|
||||
ResultSet* execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
|
||||
ResultSet* execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
|
||||
ResultSet execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
|
||||
ResultSet execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
|
||||
|
||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||
};
|
||||
|
|
|
@ -176,17 +176,17 @@ namespace sd {
|
|||
|
||||
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
||||
|
||||
|
||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
|
||||
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs);
|
||||
|
||||
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
|
||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
|
||||
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
|
||||
|
||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
||||
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
||||
|
||||
Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32);
|
||||
|
||||
sd::ResultSet* execute(const sd::OpArgsHolder& holder, bool isInplace = false);
|
||||
sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false);
|
||||
|
||||
|
||||
// There methods provide various validation options
|
||||
Nd4jStatus validateNonEmptyInput(Context& block);
|
||||
|
|
|
@ -41,8 +41,9 @@ namespace sd {
|
|||
template <typename T>
|
||||
Nd4jStatus validateAndExecute_(Context &block);
|
||||
|
||||
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
||||
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
||||
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
||||
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
||||
|
||||
Nd4jStatus execute(Context* block) override;
|
||||
|
||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||
|
|
|
@ -74,10 +74,10 @@ namespace sd {
|
|||
// at first step we build fwd activation
|
||||
sd::ops::crelu op;
|
||||
auto tmpResult = op.evaluate({input});
|
||||
if (tmpResult->status() != ND4J_STATUS_OK)
|
||||
return tmpResult->status();
|
||||
if (tmpResult.status() != ND4J_STATUS_OK)
|
||||
return tmpResult.status();
|
||||
|
||||
auto actv = tmpResult->at(0);
|
||||
auto actv = tmpResult.at(0);
|
||||
|
||||
// now we do RELU backward pass
|
||||
//actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr);
|
||||
|
@ -85,17 +85,15 @@ namespace sd {
|
|||
// now we split updated array into 2 chunks along last dimension
|
||||
sd::ops::concat_bp opc;
|
||||
auto dec = opc.evaluate({input, input, actv}, {-1});
|
||||
if (dec->status() != ND4J_STATUS_OK)
|
||||
return dec->status();
|
||||
if (dec.status() != ND4J_STATUS_OK)
|
||||
return dec.status();
|
||||
|
||||
// and now we subtract two parts of epsilons and pass result out
|
||||
auto pos = dec->at(0);
|
||||
auto neg = dec->at(1);
|
||||
auto pos = dec.at(0);
|
||||
auto neg = dec.at(1);
|
||||
|
||||
pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon);
|
||||
|
||||
delete tmpResult;
|
||||
delete dec;
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -102,10 +102,12 @@ namespace sd {
|
|||
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
|
||||
// if (output->isEmpty())
|
||||
Nd4jLong width = condition->rankOf();
|
||||
|
||||
sd::ops::Where op;
|
||||
std::unique_ptr<ResultSet> res(op.evaluate({condition}));
|
||||
REQUIRE_OK(res->status());
|
||||
NDArray* whereTrue = res->at(0);
|
||||
auto res(op.evaluate({condition}));
|
||||
REQUIRE_OK(res.status());
|
||||
NDArray* whereTrue = res.at(0);
|
||||
|
||||
if (whereTrue->isEmpty())
|
||||
return ND4J_STATUS_OK;
|
||||
for (Nd4jLong outNext = 0; outNext < width; ++outNext) {
|
||||
|
|
|
@ -65,11 +65,12 @@ namespace sd {
|
|||
auto gradX = OUTPUT_VARIABLE(0);
|
||||
auto gradY = OUTPUT_VARIABLE(1);
|
||||
gradX->assign(epsNext);
|
||||
|
||||
sd::ops::floormod op;
|
||||
std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));
|
||||
auto tmpResult(op.evaluate({x, y}));
|
||||
|
||||
if (gradY->rankOf() == gradX->rankOf())
|
||||
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);
|
||||
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
|
||||
else // epsNext is greater than gradY
|
||||
{
|
||||
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
||||
|
@ -77,7 +78,7 @@ namespace sd {
|
|||
for (Nd4jLong d = 0; d < gap; d++) {
|
||||
dims[d * 2 + 1] = 1;
|
||||
}
|
||||
auto tempIn((*tmpResult->at(0))(dims));
|
||||
auto tempIn((*tmpResult.at(0))(dims));
|
||||
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -113,23 +113,21 @@ namespace ops {
|
|||
originalIndices.linspace(0);
|
||||
ops::dynamic_partition op;
|
||||
auto res = op.evaluate({&originalIndices, indices}, {numPartition});
|
||||
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||
REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||
ops::dynamic_stitch stichOp;
|
||||
std::vector<NDArray*> partitions(numPartition * 2);
|
||||
for (size_t i = 0; i < res->size(); i++) {
|
||||
partitions[i] = res->at(i);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
partitions[i] = res.at(i);
|
||||
partitions[i + numPartition] = gradOutList[i];
|
||||
}
|
||||
|
||||
auto result = stichOp.evaluate(partitions, {numPartition});
|
||||
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
|
||||
REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||
result.at(0)->reshapei(outputList[0]->getShapeAsVector());
|
||||
outputList[1]->assign(indices);
|
||||
outputList[0]->assign(result->at(0));
|
||||
outputList[0]->assign(result.at(0));
|
||||
|
||||
// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList);
|
||||
delete res;
|
||||
delete result;
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -66,10 +66,10 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
|
|||
|
||||
sd::ops::gather op;
|
||||
|
||||
std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
|
||||
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
|
||||
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
|
||||
output->assign(result->at(0));
|
||||
auto result(op.evaluate({input, indeces}, {0}));
|
||||
REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
|
||||
REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
|
||||
output->assign(result.at(0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -95,8 +95,8 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
|||
// forward steps
|
||||
sd::ops::dynamic_rnn dynamicRnn;
|
||||
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
|
||||
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
|
||||
hFWFinal->assign(resultsFW->at(1));
|
||||
hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
|
||||
hFWFinal->assign(resultsFW.at(1));
|
||||
|
||||
auto seqLen = maxTimeStep;
|
||||
if(seqLen == nullptr) {
|
||||
|
@ -108,22 +108,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
|||
// reverse x
|
||||
sd::ops::reverse_sequence reverse;
|
||||
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
|
||||
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
|
||||
auto revInput = resultsIn->at(0);
|
||||
REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
|
||||
auto revInput = resultsIn.at(0);
|
||||
|
||||
// backward steps
|
||||
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
|
||||
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
|
||||
hBWFinal->assign(resultsBW->at(1));
|
||||
auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
|
||||
hBWFinal->assign(resultsBW.at(1));
|
||||
|
||||
// reverse hBWtemp
|
||||
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
|
||||
hBW->assign(resultsOut->at(0));
|
||||
|
||||
delete resultsOut;
|
||||
delete resultsBW;
|
||||
delete resultsIn;
|
||||
delete resultsFW;
|
||||
hBW->assign(resultsOut.at(0));
|
||||
|
||||
if(seqLen != maxTimeStep)
|
||||
delete seqLen;
|
||||
|
@ -228,12 +223,6 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -52,14 +52,13 @@ namespace helpers {
|
|||
|
||||
sd::ops::unique opUnique;
|
||||
auto uResult = opUnique.evaluate({&arrayFull});
|
||||
if (Status::OK() != uResult->status())
|
||||
if (Status::OK() != uResult.status())
|
||||
throw std::runtime_error("multiUnique: cannot execute unique op properly.");
|
||||
|
||||
auto uniqueVals = uResult->at(0);
|
||||
auto uniqueVals = uResult.at(0);
|
||||
|
||||
bool res = uniqueVals->lengthOf() == arrayFull.lengthOf();
|
||||
|
||||
delete uResult;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ namespace sd {
|
|||
block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList);
|
||||
}
|
||||
|
||||
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
||||
ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
||||
std::vector<NDArray*> ins(inputs);
|
||||
std::vector<double> tas(tArgs);
|
||||
std::vector<int> ias(iArgs);
|
||||
|
@ -94,7 +94,7 @@ namespace sd {
|
|||
return status;
|
||||
}
|
||||
|
||||
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
|
||||
ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
|
||||
VariableSpace varSpace;
|
||||
int nodeId = 119;
|
||||
|
||||
|
@ -132,8 +132,8 @@ namespace sd {
|
|||
|
||||
|
||||
Nd4jStatus result = this->validateAndExecute(block);
|
||||
auto res = new ResultSet();
|
||||
res->setStatus(result);
|
||||
ResultSet res;
|
||||
res.setStatus(result);
|
||||
|
||||
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
||||
std::pair<int,int> pair(1, e);
|
||||
|
@ -143,10 +143,10 @@ namespace sd {
|
|||
auto arr = var->getNDArray();
|
||||
if (arr->isAttached()) {
|
||||
auto d = arr->detach();
|
||||
res->push_back(d);
|
||||
res.push_back(d);
|
||||
} else {
|
||||
var->markRemovable(false);
|
||||
res->push_back(arr);
|
||||
res.push_back(arr);
|
||||
}
|
||||
}
|
||||
} else
|
||||
|
|
|
@ -962,12 +962,12 @@ namespace sd {
|
|||
return execute(&ctx);
|
||||
}
|
||||
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
|
||||
std::vector<Nd4jLong> realArgs;
|
||||
for (auto v:iArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
@ -976,12 +976,12 @@ namespace sd {
|
|||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
|
||||
std::vector<double> realArgs;
|
||||
for (auto v:tArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
@ -990,21 +990,21 @@ namespace sd {
|
|||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
||||
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
||||
}
|
||||
|
||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
|
||||
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
|
||||
VariableSpace variableSpace;
|
||||
//ResultSet arrayList;
|
||||
FlowPath fp;
|
||||
|
@ -1041,11 +1041,11 @@ namespace sd {
|
|||
block.getDArguments()->push_back(dArgs.at(e));
|
||||
|
||||
Nd4jStatus status = this->execute(&block);
|
||||
auto arrayList = new ResultSet();
|
||||
ResultSet arrayList;
|
||||
if (isInplace)
|
||||
arrayList->setNonRemovable();
|
||||
arrayList.setNonRemovable();
|
||||
|
||||
arrayList->setStatus(status);
|
||||
arrayList.setStatus(status);
|
||||
if (status != ND4J_STATUS_OK)
|
||||
return arrayList;
|
||||
|
||||
|
@ -1058,23 +1058,23 @@ namespace sd {
|
|||
if (!arr->isAttached()) {
|
||||
var->markRemovable(false);
|
||||
arr->setContext(sd::LaunchContext::defaultContext());
|
||||
arrayList->push_back(arr);
|
||||
arrayList.push_back(arr);
|
||||
} else {
|
||||
arrayList->push_back(arr->detach());
|
||||
arrayList.push_back(arr->detach());
|
||||
}
|
||||
} else
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
for (auto v:inputs) {
|
||||
arrayList->push_back(v);
|
||||
arrayList.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
return arrayList;
|
||||
}
|
||||
|
||||
sd::ResultSet* sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
|
||||
sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
|
||||
// FIXME: add DArgs to OpArgsHolder
|
||||
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace);
|
||||
}
|
||||
|
|
|
@ -357,20 +357,20 @@ namespace sd {
|
|||
return DeclarableOp::execute(block);
|
||||
}
|
||||
|
||||
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
|
||||
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
|
||||
std::vector<NDArray*> ins(inputs);
|
||||
std::vector<double> tas(tArgs);
|
||||
std::vector<int> ias(iArgs);
|
||||
return this->execute(rng, ins, tas, ias, isInplace);
|
||||
}
|
||||
|
||||
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
|
||||
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
|
||||
VariableSpace variableSpace;
|
||||
auto arrayList = new ResultSet();
|
||||
ResultSet arrayList;
|
||||
//ResultSet arrayList;
|
||||
|
||||
if (isInplace)
|
||||
arrayList->setNonRemovable();
|
||||
arrayList.setNonRemovable();
|
||||
|
||||
int cnt = -1;
|
||||
std::vector<int> in;
|
||||
|
@ -398,7 +398,7 @@ namespace sd {
|
|||
block.getIArguments()->emplace_back(iArgs.at(e));
|
||||
|
||||
Nd4jStatus status = this->execute(&block);
|
||||
arrayList->setStatus(status);
|
||||
arrayList.setStatus(status);
|
||||
if (status != ND4J_STATUS_OK)
|
||||
return arrayList;
|
||||
|
||||
|
@ -410,9 +410,9 @@ namespace sd {
|
|||
auto arr = var->getNDArray();
|
||||
if (!arr->isAttached()) {
|
||||
var->markRemovable(false);
|
||||
arrayList->push_back(arr);
|
||||
arrayList.push_back(arr);
|
||||
} else {
|
||||
arrayList->push_back(arr->detach());
|
||||
arrayList.push_back(arr->detach());
|
||||
}
|
||||
} else
|
||||
break;
|
||||
|
@ -423,4 +423,4 @@ namespace sd {
|
|||
|
||||
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,9 +44,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
|
|||
|
||||
sd::ops::dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -72,9 +70,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
|
|||
|
||||
sd::ops::dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
||||
|
@ -86,9 +82,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
|||
|
||||
sd::ops::dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -118,9 +112,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
|
|||
|
||||
sd::ops::dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -154,9 +146,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
|
|||
|
||||
sd::ops::multi_head_dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -198,9 +188,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
|
|||
|
||||
sd::ops::multi_head_dot_product_attention op;
|
||||
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -39,13 +39,11 @@ TEST_F(BackpropTests, Test_Add_1) {
|
|||
sd::ops::add_bp op;
|
||||
auto result = op.evaluate({&x, &y, &e});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto eps = result->at(0);
|
||||
auto grad = result->at(1);
|
||||
auto eps = result.at(0);
|
||||
auto grad = result.at(1);
|
||||
|
||||
ASSERT_TRUE(x.isSameShape(eps));
|
||||
ASSERT_TRUE(y.isSameShape(grad));
|
||||
|
||||
delete result;
|
||||
}
|
|
@ -137,13 +137,12 @@ TEST_F(BooleanOpsTests, test_where_1) {
|
|||
sd::ops::choose op;
|
||||
|
||||
auto result = op.evaluate({&x, &y}, {3});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
|
|
@ -48,9 +48,9 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
|
|||
sd::ops::add op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
//exp.printIndexedBuffer("E A");
|
||||
//z->printIndexedBuffer("Z");
|
||||
|
@ -58,7 +58,6 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -75,14 +74,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
|
|||
sd::ops::multiply op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -100,14 +97,13 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
|
|||
sd::ops::squaredsubtract op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -119,14 +115,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
|
|||
sd::ops::subtract op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -138,14 +132,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
|
|||
sd::ops::add op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -156,14 +148,12 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
|
|||
|
||||
sd::ops::maximum op;
|
||||
auto result = op.evaluate({&x, &row});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -174,15 +164,14 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
|
|||
|
||||
sd::ops::minimum op;
|
||||
auto result = op.evaluate({&x, &col});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -283,14 +272,13 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
|
|||
|
||||
sd::ops::add op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -333,11 +321,9 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
|
|||
|
||||
sd::ops::subtract op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
|
||||
|
@ -511,13 +497,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
|
|||
|
||||
sd::ops::multiply op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
||||
|
@ -527,13 +512,11 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
|||
|
||||
sd::ops::multiply op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -606,14 +589,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
|||
sd::ops::maximum op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||
|
@ -625,14 +606,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
|||
sd::ops::maximum op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||
|
@ -644,14 +624,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
|||
sd::ops::realdiv op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||
|
@ -663,14 +642,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
|||
sd::ops::realdiv op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||
|
@ -682,14 +660,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
|||
sd::ops::realdiv op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -718,15 +694,13 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
|
|||
sd::ops::greater op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z");
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -48,9 +48,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
|
|||
sd::ops::conv2d op;
|
||||
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status());
|
||||
}
|
||||
|
||||
TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
||||
|
@ -63,13 +61,12 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
|||
|
||||
sd::ops::conv2d op;
|
||||
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -47,13 +47,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
|
|||
|
||||
sd::ops::scatter_upd op;
|
||||
auto result = op.evaluate({ &x, &y, &w });
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
||||
|
@ -67,13 +65,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
|||
|
||||
sd::ops::scatter_upd op;
|
||||
auto result = op.evaluate({ &x, &indices, &updates });
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
|
||||
|
@ -136,13 +132,11 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
|
|||
|
||||
sd::ops::bits_hamming_distance op;
|
||||
auto result = op.evaluate({ &x, &y });
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
|
||||
|
@ -167,10 +161,8 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
|
|||
|
||||
sd::ops::cast op;
|
||||
auto result = op.evaluate({&x}, {10});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(e, *result->at(0));
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
ASSERT_EQ(e, *result.at(0));
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_range_1) {
|
||||
|
@ -717,8 +709,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
|
||||
|
@ -767,7 +757,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
|
||||
|
@ -798,7 +787,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||
|
@ -826,7 +814,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
|||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -850,7 +837,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
|
|||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
||||
|
@ -891,8 +877,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
|
||||
|
@ -937,8 +921,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
|
||||
|
@ -983,7 +965,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
|
||||
|
@ -1010,7 +991,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||
|
@ -1037,8 +1017,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
|
||||
|
@ -1061,7 +1039,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
|||
#endif
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
||||
|
@ -1096,5 +1073,4 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
|||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
|
|
@ -49,9 +49,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
|
|||
|
||||
sd::ops::compat_sparse_to_dense op;
|
||||
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
delete result;
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||
|
@ -64,9 +62,8 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
|||
|
||||
sd::ops::compat_sparse_to_dense op;
|
||||
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||
|
@ -78,11 +75,11 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
|||
|
||||
sd::ops::compat_string_split op;
|
||||
auto result = op.evaluate({&x, &delimiter});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(2, result->size());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
ASSERT_EQ(2, result.size());
|
||||
|
||||
auto z0 = result->at(0);
|
||||
auto z1 = result->at(1);
|
||||
auto z0 = result.at(0);
|
||||
auto z1 = result.at(1);
|
||||
|
||||
ASSERT_TRUE(exp0.isSameShape(z0));
|
||||
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||
|
@ -90,5 +87,4 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
|||
ASSERT_EQ(exp0, *z0);
|
||||
ASSERT_EQ(exp1, *z1);
|
||||
|
||||
delete result;
|
||||
}
|
|
@ -62,10 +62,8 @@ TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
|||
|
||||
sd::ops::conv1d_bp op;
|
||||
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_squeeze_1) {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -51,14 +51,12 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
|||
sd::ops::choose op;
|
||||
//greater than test
|
||||
auto result = op.evaluate({&x}, {0.0},{3});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(1);
|
||||
auto z = result.at(1);
|
||||
|
||||
ASSERT_EQ(148,z->e<double>(0));
|
||||
//ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -67,9 +67,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({empty, vector}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z shape");
|
||||
// z->printIndexedBuffer("z buffr");
|
||||
|
@ -78,7 +78,6 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
|||
|
||||
delete empty;
|
||||
delete vector;
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -92,9 +91,9 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z shape");
|
||||
// z->printIndexedBuffer("z buffr");
|
||||
|
@ -104,7 +103,6 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
|||
delete empty;
|
||||
delete scalar1;
|
||||
delete scalar2;
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_3) {
|
||||
|
@ -117,13 +115,12 @@ TEST_F(EmptyTests, Test_Concat_3) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(exp, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_4) {
|
||||
|
@ -136,13 +133,11 @@ TEST_F(EmptyTests, Test_Concat_4) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(exp, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_1) {
|
||||
|
@ -153,12 +148,11 @@ TEST_F(EmptyTests, Test_Reshape_1) {
|
|||
sd::ops::reshape op;
|
||||
auto result = op.evaluate({&vector, empty}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
ASSERT_EQ(exp, *result->at(0));
|
||||
ASSERT_EQ(exp, *result.at(0));
|
||||
|
||||
delete empty;
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||
|
@ -168,14 +162,13 @@ TEST_F(EmptyTests, Test_Reshape_3) {
|
|||
|
||||
sd::ops::reshape op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_dup_1) {
|
||||
|
@ -197,12 +190,11 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
|
|||
|
||||
sd::ops::scatter_upd op;
|
||||
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_EQ(x, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_empty_scatter_2) {
|
||||
|
@ -288,17 +280,15 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
|
|||
|
||||
sd::ops::reshape op;
|
||||
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result0->status());
|
||||
auto z0 = result0->at(0);
|
||||
ASSERT_EQ(Status::OK(), result0.status());
|
||||
auto z0 = result0.at(0);
|
||||
ASSERT_EQ(e0, *z0);
|
||||
|
||||
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result1->status());
|
||||
auto z1 = result1->at(0);
|
||||
ASSERT_EQ(Status::OK(), result1.status());
|
||||
auto z1 = result1.at(0);
|
||||
ASSERT_EQ(e1, *z1);
|
||||
|
||||
delete result0;
|
||||
delete result1;
|
||||
}
|
||||
|
||||
|
||||
|
@ -309,12 +299,11 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
|
|||
|
||||
sd::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_empty_matmul_2) {
|
||||
|
@ -324,10 +313,8 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
|
|||
|
||||
sd::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
|
|
@ -1889,20 +1889,19 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) {
|
|||
OpArgsHolder holderFF({&input}, {}, {2, 3});
|
||||
sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly
|
||||
auto results = opFF.execute(holderFF);
|
||||
auto tiled = results->at(0);
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto tiled = results.at(0);
|
||||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(exp.isSameShape(tiled));
|
||||
ASSERT_TRUE(exp.equalsTo(tiled));
|
||||
delete results;
|
||||
|
||||
|
||||
OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true);
|
||||
sd::ops::tile_bp opBP;
|
||||
results = opBP.execute(holderBP);
|
||||
auto gradI = results->at(0);
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto gradI = results.at(0);
|
||||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -47,13 +47,13 @@ TEST_F(IndexingTests, StridedSlice_1) {
|
|||
sd::ops::strided_slice op;
|
||||
|
||||
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -66,14 +66,14 @@ TEST_F(IndexingTests, StridedSlice_2) {
|
|||
sd::ops::strided_slice op;
|
||||
|
||||
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -86,14 +86,14 @@ TEST_F(IndexingTests, StridedSlice_3) {
|
|||
sd::ops::strided_slice op;
|
||||
|
||||
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -109,15 +109,15 @@ TEST_F(IndexingTests, SimpleSlice_1) {
|
|||
sd::ops::slice op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -135,15 +135,15 @@ TEST_F(IndexingTests, SimpleSlice_2) {
|
|||
sd::ops::slice op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, SimpleSlice_3) {
|
||||
|
@ -160,15 +160,15 @@ TEST_F(IndexingTests, SimpleSlice_3) {
|
|||
sd::ops::slice op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, SimpleSlice_4) {
|
||||
|
@ -180,14 +180,14 @@ TEST_F(IndexingTests, SimpleSlice_4) {
|
|||
sd::ops::slice op;
|
||||
|
||||
auto result = op.evaluate({&input, &start, &stop});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -204,16 +204,16 @@ TEST_F(IndexingTests, MaskedSlice_0) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -230,14 +230,14 @@ TEST_F(IndexingTests, MaskedSlice_00) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -254,16 +254,16 @@ TEST_F(IndexingTests, MaskedSlice_1) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, MaskedSlice_2) {
|
||||
|
@ -275,14 +275,14 @@ TEST_F(IndexingTests, MaskedSlice_2) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -295,14 +295,14 @@ TEST_F(IndexingTests, MaskedSlice_3) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -315,15 +315,15 @@ TEST_F(IndexingTests, MaskedSlice_4) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, Live_Slice_1) {
|
||||
|
@ -338,16 +338,16 @@ TEST_F(IndexingTests, Live_Slice_1) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printShapeInfo("z shape");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -361,14 +361,14 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, Test_StridedSlice_2) {
|
||||
|
@ -381,16 +381,16 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
// z->printIndexedBuffer("Z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -404,14 +404,14 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
|
|||
sd::ops::strided_slice op;
|
||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -426,16 +426,16 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
|
|||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
//z->printIndexedBuffer("Z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(IndexingTests, Test_Subarray_Strided_1) {
|
||||
|
@ -458,13 +458,13 @@ TEST_F(IndexingTests, MaskedSlice_5) {
|
|||
sd::ops::strided_slice<float> op;
|
||||
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
*/
|
|
@ -64,13 +64,13 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
|
|||
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, Reciprocal_1) {
|
||||
|
@ -121,12 +121,12 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
|
|||
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
//z->printBuffer("Z");
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, Scalar_Test_1) {
|
||||
|
@ -154,10 +154,10 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
|
|||
sd::ops::LegacyScalarOp op(scalar::Add, y);
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -169,14 +169,14 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
|
|||
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printBuffer("ReduceTest1");
|
||||
ASSERT_TRUE(z->isScalar());
|
||||
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -188,16 +188,16 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
|
|||
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
|
||||
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
auto exp = x.reduceAlongDimension(reduce::Sum, {1});
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -209,15 +209,15 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
|
|||
|
||||
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
auto exp = x.reduceAlongDimension(reduce::Sum,{1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -229,16 +229,16 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
|
|||
|
||||
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
|
||||
// indices.printShapeInfo("Indices shape");
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
// z->printIndexedBuffer("Output reduce 4");
|
||||
// exp.printIndexedBuffer("Expected reduce 4");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, ReduceTests_5) {
|
||||
|
@ -249,14 +249,14 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
|
|||
|
||||
auto result = op.evaluate({&x});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printBuffer("ReduceTest1");
|
||||
ASSERT_TRUE(z->isScalar());
|
||||
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -268,16 +268,16 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
|
|||
|
||||
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
auto exp = x.reduceAlongDimension(reduce::Mean, {1});
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -289,15 +289,15 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
|
|||
|
||||
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
auto exp = x.reduceAlongDimension(reduce::Mean,{1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -309,17 +309,17 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
|
|||
|
||||
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
// z->printIndexedBuffer("Reduce8 output");
|
||||
// z->printShapeInfo("Reduce8 shape");
|
||||
// exp.printShapeInfo("Reduce8 expected shape");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -331,14 +331,14 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
|
|||
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(z->isScalar());
|
||||
ASSERT_EQ(24, z->e<int>(0));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -351,9 +351,9 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
|||
|
||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer("Hello indexreduce2");
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
//ASSERT_EQ(4, z->e<int>(0));
|
||||
|
@ -362,7 +362,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
|||
//ASSERT_EQ(4, z->e<int>(3));
|
||||
//ASSERT_EQ(4, z->e<int>(4));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
|
|||
|
||||
auto result = op.execute(&list, {&x}, {}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
ASSERT_EQ(1, list.elements());
|
||||
|
||||
|
@ -47,8 +47,8 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
|
|||
|
||||
ASSERT_EQ(2, list.elements());
|
||||
|
||||
delete result;
|
||||
delete result2;
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
||||
|
@ -66,15 +66,15 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
|||
|
||||
auto result = op.execute(&list, {}, {}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printShapeInfo();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||
|
@ -93,10 +93,10 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
|||
|
||||
auto result = op.execute(&list, {&x}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
ASSERT_EQ(list.elements(), 10);
|
||||
|
||||
// auto z = result->at(0);
|
||||
// auto z = result.at(0);
|
||||
// z->printShapeInfo("The first of");
|
||||
// ASSERT_TRUE(exp.isSameShape(z));
|
||||
// ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
@ -107,7 +107,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
|||
delete row;
|
||||
}
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
|
||||
|
@ -126,20 +126,20 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
|||
//
|
||||
// auto result = op.execute(nullptr, {&x}, {}, {0});
|
||||
//
|
||||
// ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
// ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
// ASSERT_EQ(result->size(), 10);
|
||||
//
|
||||
// // auto z = result->at(0);
|
||||
// // auto z = result.at(0);
|
||||
//// z->printShapeInfo("The first of");
|
||||
//// ASSERT_TRUE(exp.isSameShape(z));
|
||||
//// ASSERT_TRUE(exp.equalsTo(z));
|
||||
// for (int e = 0; e < 10; e++) {
|
||||
// auto row = result->at(e);
|
||||
// auto row = result.at(e);
|
||||
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
|
||||
// //list.write(e, row);
|
||||
// }
|
||||
//
|
||||
// delete result;
|
||||
//
|
||||
// delete tads;
|
||||
//}
|
||||
|
||||
|
@ -160,14 +160,14 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) {
|
|||
|
||||
auto result = op.execute(&list, {}, {}, {4});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
||||
|
@ -192,14 +192,14 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
|||
sd::ops::pick_list op;
|
||||
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
||||
|
@ -217,14 +217,14 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
|||
|
||||
auto result = op.execute(&list, {}, {}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
||||
|
@ -235,12 +235,12 @@ TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
|||
|
||||
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
// we return flow as well
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
||||
|
@ -283,7 +283,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
|||
|
||||
sd::ops::split_list op;
|
||||
auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
ASSERT_EQ(3, list.height());
|
||||
|
||||
|
@ -296,7 +296,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
|||
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
|
||||
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
||||
|
@ -319,7 +319,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
|||
sd::ops::scatter_list op;
|
||||
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
for (int e = 0; e < 10; e++) {
|
||||
auto row = tads.at(9 - e);
|
||||
|
@ -329,7 +329,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
|||
|
||||
ASSERT_TRUE(chunk->equalsTo(row));
|
||||
}
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
|
||||
|
@ -385,10 +385,10 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
|
|||
sd::ops::gather_list op;
|
||||
auto result = op.execute(&list, {&indices}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(1, result->size());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
ASSERT_EQ(1, result.size());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
|
@ -397,7 +397,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {
|
||||
|
|
|
@ -134,13 +134,11 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
|
|||
|
||||
sd::ops::add op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -66,7 +66,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row0 = syn0({0,1, 0,0}, true);
|
||||
auto row1 = syn1({1,2, 0,0}, true);
|
||||
|
@ -74,7 +74,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
|
|||
ASSERT_EQ(exp0, row0);
|
||||
ASSERT_EQ(exp1, row1);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_sg_hs_test_2) {
|
||||
|
@ -107,7 +107,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row0 = syn0({0,1, 0,0}, true);
|
||||
auto row1 = syn1({1,2, 0,0}, true);
|
||||
|
@ -117,7 +117,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
|
|||
ASSERT_EQ(exp1, row1);
|
||||
ASSERT_EQ(exp2, row2);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_sg_hs_test_3) {
|
||||
|
@ -159,7 +159,7 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
|
|||
sd::ops::skipgram op;
|
||||
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result0->status());
|
||||
ASSERT_EQ(Status::OK(), result0.status());
|
||||
|
||||
auto row00 = syn00({0,1, 0,0}, true);
|
||||
auto row01 = syn01({0,1, 0,0}, true);
|
||||
|
@ -168,9 +168,6 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
|
|||
|
||||
ASSERT_EQ(row2, row1);
|
||||
ASSERT_EQ(row00, row01);
|
||||
|
||||
delete result0;
|
||||
delete result1;
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
||||
|
@ -192,9 +189,9 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_sg_ns_test_1) {
|
||||
|
@ -227,14 +224,14 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row0 = syn0({1,2, 0,0}, true);
|
||||
|
||||
ASSERT_EQ(exp0, row0);
|
||||
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_cb_hs_test_1) {
|
||||
|
@ -269,7 +266,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
|
|||
|
||||
sd::ops::cbow op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||
|
@ -287,7 +284,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
|
|||
ASSERT_EQ(exp1, row_s1_5);
|
||||
ASSERT_EQ(exp2, row_s1_6);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, basic_cb_ns_test_1) {
|
||||
|
@ -323,7 +320,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
|
|||
|
||||
sd::ops::cbow op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||
|
@ -339,7 +336,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
|
|||
ASSERT_EQ(exp0, row_s0_2);
|
||||
ASSERT_EQ(exp2, row_s1_6);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, test_sg_hs_batch_1) {
|
||||
|
@ -372,7 +369,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto row0 = syn0({0,1, 0,0}, true);
|
||||
auto row1 = syn1({1,2, 0,0}, true);
|
||||
|
@ -382,7 +379,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
|
|||
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
|
||||
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, test_sg_ns_batch_1) {
|
||||
|
@ -416,9 +413,9 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
|
|||
|
||||
sd::ops::skipgram op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||
|
@ -449,7 +446,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
|||
|
||||
sd::ops::cbow op;
|
||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||
|
@ -473,6 +470,5 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
|||
ASSERT_EQ(exp1, row_s1_4);
|
||||
ASSERT_EQ(exp1, row_s1_5);
|
||||
ASSERT_EQ(exp2, row_s1_6);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -260,9 +260,9 @@ TEST_F(RNGTests, Test_Gaussian_21) {
|
|||
auto result = op.evaluate({&x0}, {}, {});
|
||||
//x0.printIndexedBuffer("X0 Normal");
|
||||
//x1.printIndexedBuffer("X1 Normal");
|
||||
ASSERT_TRUE(result->status() == Status::OK());
|
||||
auto mean = result->at(0);
|
||||
auto variance = result->at(1);
|
||||
ASSERT_TRUE(result.status() == Status::OK());
|
||||
auto mean = result.at(0);
|
||||
auto variance = result.at(1);
|
||||
|
||||
// mean->printIndexedBuffer("Mean");
|
||||
// variance->printIndexedBuffer("Variance");
|
||||
|
@ -270,7 +270,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
|
|||
ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f);
|
||||
ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f);
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BUILD
|
||||
|
@ -292,15 +292,15 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
|||
auto result = op.evaluate({&x0}, {}, {});
|
||||
//x0.printIndexedBuffer("X0 Normal");
|
||||
//x1.printIndexedBuffer("X1 Normal");
|
||||
ASSERT_TRUE(result->status() == Status::OK());
|
||||
auto mean0 = result->at(0);
|
||||
auto variance0 = result->at(1);
|
||||
ASSERT_TRUE(result.status() == Status::OK());
|
||||
auto mean0 = result.at(0);
|
||||
auto variance0 = result.at(1);
|
||||
|
||||
//mean0->printIndexedBuffer("Mean");
|
||||
//variance0->printIndexedBuffer("Variance");
|
||||
ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f);
|
||||
ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f);
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Gaussian_3) {
|
||||
|
@ -413,18 +413,17 @@ TEST_F(RNGTests, Test_Truncated_21) {
|
|||
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
|
||||
sd::ops::moments op;
|
||||
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
||||
// result->at(0)->printBuffer("MEAN");
|
||||
// result->at(1)->printBuffer("VARIANCE");
|
||||
delete result;
|
||||
|
||||
// result.at(0)->printBuffer("MEAN");
|
||||
// result.at(1)->printBuffer("VARIANCE");
|
||||
|
||||
sd::ops::reduce_min minOp;
|
||||
sd::ops::reduce_max maxOp;
|
||||
|
||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||
// minRes->at(0)->printBuffer("MIN for Truncated");
|
||||
// maxRes->at(0)->printBuffer("MAX for Truncated");
|
||||
|
||||
delete minRes;
|
||||
delete maxRes;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Truncated_22) {
|
||||
|
@ -460,18 +459,15 @@ TEST_F(RNGTests, Test_Truncated_22) {
|
|||
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
|
||||
sd::ops::moments op;
|
||||
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
||||
// result->at(0)->printBuffer("MEAN");
|
||||
// result->at(1)->printBuffer("VARIANCE");
|
||||
delete result;
|
||||
|
||||
sd::ops::reduce_min minOp;
|
||||
sd::ops::reduce_max maxOp;
|
||||
|
||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||
// minRes->at(0)->printBuffer("MIN for Truncated2");
|
||||
// maxRes->at(0)->printBuffer("MAX for Truncated2");
|
||||
|
||||
delete minRes;
|
||||
delete maxRes;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Truncated_23) {
|
||||
|
@ -509,16 +505,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
|
|||
auto result = op.evaluate({&x0});
|
||||
// result->at(0)->printBuffer("MEAN");
|
||||
// result->at(1)->printBuffer("VARIANCE");
|
||||
delete result;
|
||||
sd::ops::reduce_min minOp;
|
||||
sd::ops::reduce_max maxOp;
|
||||
|
||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||
// minRes->at(0)->printBuffer("MIN for Truncated3");
|
||||
// maxRes->at(0)->printBuffer("MAX for Truncated3");
|
||||
|
||||
delete minRes;
|
||||
delete maxRes;
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Truncated_3) {
|
||||
|
@ -568,15 +562,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(0);
|
||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Gaussian_2) {
|
||||
|
@ -588,15 +582,15 @@ TEST_F(RNGTests, Test_Gaussian_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution);
|
||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_LogNorm_2) {
|
||||
|
@ -608,15 +602,15 @@ TEST_F(RNGTests, Test_LogNorm_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution);
|
||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_TruncatedNorm_2) {
|
||||
|
@ -628,14 +622,14 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution);
|
||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -648,15 +642,15 @@ TEST_F(RNGTests, Test_Binomial_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx);
|
||||
auto result = op->execute(_rngA, {&input}, {0.5f}, {3});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -669,15 +663,15 @@ TEST_F(RNGTests, Test_Bernoulli_2) {
|
|||
auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution);
|
||||
auto result = op->execute(_rngA, {&input}, {0.5f}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(x1.isSameShape(z));
|
||||
ASSERT_TRUE(x1.equalsTo(z));
|
||||
|
||||
delete op;
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
||||
|
@ -687,9 +681,9 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
|||
|
||||
sd::ops::random_normal op;
|
||||
auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
|
@ -698,7 +692,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
|||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
||||
|
@ -708,9 +702,9 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
|||
|
||||
sd::ops::random_bernoulli op;
|
||||
auto result = op.evaluate({&x}, {0.5f}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
|
@ -718,7 +712,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
|||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -729,9 +723,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
|
|||
|
||||
sd::ops::random_exponential op;
|
||||
auto result = op.evaluate({&x}, {0.25f}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
|
@ -740,7 +734,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
|
|||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||
|
@ -753,9 +747,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
|||
|
||||
sd::ops::random_exponential op;
|
||||
auto result = op.evaluate({&x, &y}, {0.25f}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
|
@ -764,7 +758,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
|||
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
||||
|
@ -777,14 +771,14 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
|||
|
||||
sd::ops::random_poisson op;
|
||||
auto result = op.evaluate({&x, &la}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer("Poisson distribution");
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_GammaDistribution_1) {
|
||||
|
@ -797,14 +791,14 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {
|
|||
|
||||
sd::ops::random_gamma op;
|
||||
auto result = op.evaluate({&x, &al}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer("Gamma distribution");
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_GammaDistribution_2) {
|
||||
|
@ -818,14 +812,14 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
|
|||
|
||||
sd::ops::random_gamma op;
|
||||
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer("Gamma distribution");
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_GammaDistribution_3) {
|
||||
|
@ -839,14 +833,14 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
|
|||
|
||||
sd::ops::random_gamma op;
|
||||
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer("Gamma distribution");
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_UniformDistribution_04) {
|
||||
|
@ -858,13 +852,13 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
|
|||
|
||||
sd::ops::randomuniform op;
|
||||
auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp0.isSameShape(z));
|
||||
ASSERT_FALSE(exp0.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
namespace sd {
|
||||
|
@ -1021,12 +1015,12 @@ TEST_F(RNGTests, test_multinomial_1) {
|
|||
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64);
|
||||
|
||||
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
|
||||
auto outputZ = result->at(0);
|
||||
auto outputZ = result.at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
|
||||
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, test_multinomial_2) {
|
||||
|
@ -1117,8 +1111,8 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
}
|
||||
|
||||
auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
|
||||
auto outputR = resultR->at(0);
|
||||
ASSERT_EQ(Status::OK(), resultR->status());
|
||||
auto outputR = resultR.at(0);
|
||||
ASSERT_EQ(Status::OK(), resultR.status());
|
||||
|
||||
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||
mean = outputR->meanNumber();
|
||||
|
@ -1131,7 +1125,6 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
||||
}
|
||||
|
||||
delete resultR;
|
||||
}
|
||||
|
||||
|
||||
|
@ -1150,8 +1143,8 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
||||
|
||||
auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
|
||||
auto outputR = resultR->at(0);
|
||||
ASSERT_EQ(Status::OK(), resultR->status());
|
||||
auto outputR = resultR.at(0);
|
||||
ASSERT_EQ(Status::OK(), resultR.status());
|
||||
|
||||
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE);
|
||||
|
||||
|
@ -1175,7 +1168,7 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
|
||||
delete resultR;
|
||||
|
||||
|
||||
RandomGenerator rng(1234, 1234);
|
||||
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
||||
|
|
|
@ -96,14 +96,14 @@ TEST_F(ScalarTests, Test_Concat_1) {
|
|||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -116,15 +116,15 @@ TEST_F(ScalarTests, Test_Concat_2) {
|
|||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
// z->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -137,16 +137,16 @@ TEST_F(ScalarTests, Test_Concat_3) {
|
|||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
//z->printShapeInfo("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ScalarTests, Test_ExpandDims_1) {
|
||||
|
@ -156,14 +156,14 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
|
|||
sd::ops::expand_dims op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ScalarTests, Test_Squeeze_1) {
|
||||
|
@ -172,14 +172,14 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
|
|||
|
||||
sd::ops::squeeze op;
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -189,14 +189,14 @@ TEST_F(ScalarTests, Test_Reshape_1) {
|
|||
|
||||
sd::ops::reshape op;
|
||||
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -206,14 +206,14 @@ TEST_F(ScalarTests, Test_Permute_1) {
|
|||
|
||||
sd::ops::permute op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
||||
|
@ -225,14 +225,13 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
|
@ -245,12 +244,11 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
|
|||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
|
@ -308,14 +308,13 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
|
|||
|
||||
sd::ops::transpose op;
|
||||
auto result = op.evaluate({&x});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(ShapeTests, Tests_Transpose_119_3) {
|
||||
|
|
|
@ -70,14 +70,14 @@ TEST_F(SingleDimTests, Test_Concat_1) {
|
|||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(SingleDimTests, Test_Reduce_1) {
|
||||
|
@ -104,14 +104,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
|
|||
sd::ops::expand_dims op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -122,14 +122,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
|
|||
sd::ops::expand_dims op;
|
||||
auto result = op.evaluate({&x}, {}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -142,14 +142,14 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
|
|||
sd::ops::squeeze op;
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_EQ(exp.rankOf(), z->rankOf());
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(SingleDimTests, Test_Squeeze_2) {
|
||||
|
@ -158,14 +158,14 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
|
|||
|
||||
sd::ops::squeeze op;
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(SingleDimTests, Test_Reshape_1) {
|
||||
|
@ -174,14 +174,14 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
|
|||
|
||||
sd::ops::reshape op;
|
||||
auto result = op.evaluate({&x}, {}, {-99, 3});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(SingleDimTests, Test_Reshape_2) {
|
||||
|
@ -190,14 +190,14 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
|
|||
|
||||
sd::ops::reshape op;
|
||||
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -207,12 +207,12 @@ TEST_F(SingleDimTests, Test_Permute_1) {
|
|||
|
||||
sd::ops::permute op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result->at(0);
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
Loading…
Reference in New Issue