Improve ResultSet usage in libnd4j (#281)
* libnd4j profiling DeclarableOp and Tests by replacing return ResultSet pointer by instance Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j profiling semantic change in tests cases Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections to make new ResultSet semantic works, fixed one test Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j more tests fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - correct copy and move assignment operators of ResultSet class Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
57210b936c
commit
c3223dbc7a
|
@ -15,8 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// This class is suited for execution results representation.
|
// This class is suited for execution results representation.
|
||||||
//
|
//
|
||||||
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
@ -33,13 +33,15 @@
|
||||||
namespace sd {
|
namespace sd {
|
||||||
|
|
||||||
class NDArray; // forward declaration of template class NDArray
|
class NDArray; // forward declaration of template class NDArray
|
||||||
|
|
||||||
class ND4J_EXPORT ResultSet {
|
class ND4J_EXPORT ResultSet {
|
||||||
private:
|
private:
|
||||||
std::vector<sd::NDArray *> _content;
|
std::vector<sd::NDArray *> _content;
|
||||||
Nd4jStatus _status = ND4J_STATUS_OK;
|
Nd4jStatus _status = ND4J_STATUS_OK;
|
||||||
bool _removable = true;
|
bool _removable = true;
|
||||||
|
|
||||||
|
void delContent();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ResultSet();
|
explicit ResultSet();
|
||||||
|
|
||||||
|
@ -56,7 +58,7 @@ namespace sd {
|
||||||
|
|
||||||
// move assignment operator
|
// move assignment operator
|
||||||
ResultSet& operator=(ResultSet&& other) noexcept;
|
ResultSet& operator=(ResultSet&& other) noexcept;
|
||||||
|
|
||||||
~ResultSet();
|
~ResultSet();
|
||||||
|
|
||||||
int size();
|
int size();
|
||||||
|
|
|
@ -160,9 +160,7 @@ namespace sd {
|
||||||
|
|
||||||
auto result = op.evaluate(inputs);
|
auto result = op.evaluate(inputs);
|
||||||
|
|
||||||
auto array = new NDArray(result->at(0)->dup());
|
auto array = new NDArray(result.at(0)->dup());
|
||||||
|
|
||||||
delete result;
|
|
||||||
|
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,15 +77,16 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// move assignment operator
|
// move assignment operator
|
||||||
ResultSet& ResultSet::operator=(ResultSet&& other) noexcept {
|
ResultSet& ResultSet::operator=(ResultSet&& other) noexcept {
|
||||||
|
|
||||||
if (this == &other)
|
if (this == &other)
|
||||||
return *this;
|
return *this;
|
||||||
|
|
||||||
this->~ResultSet();
|
delContent();
|
||||||
|
|
||||||
_content = std::move(other._content);
|
_content = std::move(other._content);
|
||||||
|
|
||||||
_status = other._status;
|
_status = other._status;
|
||||||
_removable = other._removable;
|
_removable = other._removable;
|
||||||
other._removable = false;
|
other._removable = false;
|
||||||
|
@ -98,10 +99,10 @@ namespace sd {
|
||||||
if (this == &other)
|
if (this == &other)
|
||||||
return *this;
|
return *this;
|
||||||
|
|
||||||
this->~ResultSet();
|
delContent();
|
||||||
|
|
||||||
for (const auto v:other._content)
|
for (const auto v : other._content)
|
||||||
_content.emplace_back(v);
|
_content.push_back(v);
|
||||||
|
|
||||||
_status = other._status;
|
_status = other._status;
|
||||||
_removable = false;
|
_removable = false;
|
||||||
|
@ -109,11 +110,15 @@ namespace sd {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ResultSet::delContent() {
|
||||||
|
if (_removable)
|
||||||
|
for (auto v : _content)
|
||||||
|
delete v;
|
||||||
|
}
|
||||||
|
|
||||||
ResultSet::~ResultSet() {
|
ResultSet::~ResultSet() {
|
||||||
if (_removable)
|
|
||||||
for (auto v: _content)
|
delContent();
|
||||||
delete v;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ResultSet::setNonRemovable() {
|
void ResultSet::setNonRemovable() {
|
||||||
|
|
|
@ -60,7 +60,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
||||||
|
|
||||||
// back prop pass
|
// back prop pass
|
||||||
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
||||||
|
|
||||||
NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
|
NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
|
||||||
|
|
||||||
|
@ -78,18 +78,17 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
|
|
||||||
// add epsilon, feed forward
|
// add epsilon, feed forward
|
||||||
inArrsFF[i]->p<double>(j, orig + EPSILON);
|
inArrsFF[i]->p<double>(j, orig + EPSILON);
|
||||||
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
|
ResultSet outArrsFF = opFF.execute(argsHolderFF);
|
||||||
int numOutArrs = outArrsFF->size();
|
int numOutArrs = outArrsFF.size();
|
||||||
double scorePlus = 0.;
|
double scorePlus = 0.;
|
||||||
|
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
if(loss == SUM)
|
if(loss == SUM)
|
||||||
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
else
|
else
|
||||||
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
scorePlus += tmpScalar.e<double>(0);
|
scorePlus += tmpScalar.e<double>(0);
|
||||||
}
|
}
|
||||||
delete outArrsFF;
|
|
||||||
|
|
||||||
// subtract epsilon, feed forward
|
// subtract epsilon, feed forward
|
||||||
inArrsFF[i]->p<double>(j, orig - EPSILON);
|
inArrsFF[i]->p<double>(j, orig - EPSILON);
|
||||||
|
@ -98,12 +97,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
|
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
if(loss == SUM)
|
if(loss == SUM)
|
||||||
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
else
|
else
|
||||||
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
scoreMinus += tmpScalar.e<double>(0);
|
scoreMinus += tmpScalar.e<double>(0);
|
||||||
}
|
}
|
||||||
delete outArrsFF;
|
|
||||||
|
|
||||||
// restore initial element value
|
// restore initial element value
|
||||||
inArrsFF[i]->p<double>(j, orig);
|
inArrsFF[i]->p<double>(j, orig);
|
||||||
|
@ -116,7 +114,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
// get analytical gradient
|
// get analytical gradient
|
||||||
const double analyticGrad = outArrsBP->at(i)->e<double>(j);
|
const double analyticGrad = outArrsBP.at(i)->e<double>(j);
|
||||||
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
|
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
|
||||||
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
|
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
|
||||||
throw std::runtime_error("");
|
throw std::runtime_error("");
|
||||||
|
@ -138,13 +136,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
continue;
|
continue;
|
||||||
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
|
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
|
||||||
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
|
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
|
||||||
delete outArrsBP;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete outArrsBP;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,8 +45,8 @@ namespace sd {
|
||||||
Nd4jStatus execute(Context* block) override;
|
Nd4jStatus execute(Context* block) override;
|
||||||
|
|
||||||
|
|
||||||
ResultSet* execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
|
ResultSet execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs);
|
||||||
ResultSet* execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
|
ResultSet execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs);
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -176,17 +176,17 @@ namespace sd {
|
||||||
|
|
||||||
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
||||||
|
|
||||||
|
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs);
|
||||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
|
|
||||||
|
|
||||||
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
|
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
|
||||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
|
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
|
||||||
|
|
||||||
sd::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
sd::ResultSet evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false);
|
||||||
|
|
||||||
Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32);
|
Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, const std::vector<sd::DataType> &dArgs = std::vector<sd::DataType>(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32);
|
||||||
|
|
||||||
sd::ResultSet* execute(const sd::OpArgsHolder& holder, bool isInplace = false);
|
sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false);
|
||||||
|
|
||||||
|
|
||||||
// There methods provide various validation options
|
// There methods provide various validation options
|
||||||
Nd4jStatus validateNonEmptyInput(Context& block);
|
Nd4jStatus validateNonEmptyInput(Context& block);
|
||||||
|
|
|
@ -41,8 +41,9 @@ namespace sd {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Nd4jStatus validateAndExecute_(Context &block);
|
Nd4jStatus validateAndExecute_(Context &block);
|
||||||
|
|
||||||
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace = false);
|
||||||
sd::ResultSet* execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace = false);
|
||||||
|
|
||||||
Nd4jStatus execute(Context* block) override;
|
Nd4jStatus execute(Context* block) override;
|
||||||
|
|
||||||
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override;
|
||||||
|
|
|
@ -74,10 +74,10 @@ namespace sd {
|
||||||
// at first step we build fwd activation
|
// at first step we build fwd activation
|
||||||
sd::ops::crelu op;
|
sd::ops::crelu op;
|
||||||
auto tmpResult = op.evaluate({input});
|
auto tmpResult = op.evaluate({input});
|
||||||
if (tmpResult->status() != ND4J_STATUS_OK)
|
if (tmpResult.status() != ND4J_STATUS_OK)
|
||||||
return tmpResult->status();
|
return tmpResult.status();
|
||||||
|
|
||||||
auto actv = tmpResult->at(0);
|
auto actv = tmpResult.at(0);
|
||||||
|
|
||||||
// now we do RELU backward pass
|
// now we do RELU backward pass
|
||||||
//actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr);
|
//actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr);
|
||||||
|
@ -85,17 +85,15 @@ namespace sd {
|
||||||
// now we split updated array into 2 chunks along last dimension
|
// now we split updated array into 2 chunks along last dimension
|
||||||
sd::ops::concat_bp opc;
|
sd::ops::concat_bp opc;
|
||||||
auto dec = opc.evaluate({input, input, actv}, {-1});
|
auto dec = opc.evaluate({input, input, actv}, {-1});
|
||||||
if (dec->status() != ND4J_STATUS_OK)
|
if (dec.status() != ND4J_STATUS_OK)
|
||||||
return dec->status();
|
return dec.status();
|
||||||
|
|
||||||
// and now we subtract two parts of epsilons and pass result out
|
// and now we subtract two parts of epsilons and pass result out
|
||||||
auto pos = dec->at(0);
|
auto pos = dec.at(0);
|
||||||
auto neg = dec->at(1);
|
auto neg = dec.at(1);
|
||||||
|
|
||||||
pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon);
|
pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon);
|
||||||
|
|
||||||
delete tmpResult;
|
|
||||||
delete dec;
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -102,10 +102,12 @@ namespace sd {
|
||||||
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
|
REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width());
|
||||||
// if (output->isEmpty())
|
// if (output->isEmpty())
|
||||||
Nd4jLong width = condition->rankOf();
|
Nd4jLong width = condition->rankOf();
|
||||||
|
|
||||||
sd::ops::Where op;
|
sd::ops::Where op;
|
||||||
std::unique_ptr<ResultSet> res(op.evaluate({condition}));
|
auto res(op.evaluate({condition}));
|
||||||
REQUIRE_OK(res->status());
|
REQUIRE_OK(res.status());
|
||||||
NDArray* whereTrue = res->at(0);
|
NDArray* whereTrue = res.at(0);
|
||||||
|
|
||||||
if (whereTrue->isEmpty())
|
if (whereTrue->isEmpty())
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
for (Nd4jLong outNext = 0; outNext < width; ++outNext) {
|
for (Nd4jLong outNext = 0; outNext < width; ++outNext) {
|
||||||
|
|
|
@ -65,11 +65,12 @@ namespace sd {
|
||||||
auto gradX = OUTPUT_VARIABLE(0);
|
auto gradX = OUTPUT_VARIABLE(0);
|
||||||
auto gradY = OUTPUT_VARIABLE(1);
|
auto gradY = OUTPUT_VARIABLE(1);
|
||||||
gradX->assign(epsNext);
|
gradX->assign(epsNext);
|
||||||
|
|
||||||
sd::ops::floormod op;
|
sd::ops::floormod op;
|
||||||
std::unique_ptr<ResultSet> tmpResult(op.evaluate({x, y}));
|
auto tmpResult(op.evaluate({x, y}));
|
||||||
|
|
||||||
if (gradY->rankOf() == gradX->rankOf())
|
if (gradY->rankOf() == gradX->rankOf())
|
||||||
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);
|
epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult.at(0), *gradY);
|
||||||
else // epsNext is greater than gradY
|
else // epsNext is greater than gradY
|
||||||
{
|
{
|
||||||
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
|
||||||
|
@ -77,7 +78,7 @@ namespace sd {
|
||||||
for (Nd4jLong d = 0; d < gap; d++) {
|
for (Nd4jLong d = 0; d < gap; d++) {
|
||||||
dims[d * 2 + 1] = 1;
|
dims[d * 2 + 1] = 1;
|
||||||
}
|
}
|
||||||
auto tempIn((*tmpResult->at(0))(dims));
|
auto tempIn((*tmpResult.at(0))(dims));
|
||||||
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -113,23 +113,21 @@ namespace ops {
|
||||||
originalIndices.linspace(0);
|
originalIndices.linspace(0);
|
||||||
ops::dynamic_partition op;
|
ops::dynamic_partition op;
|
||||||
auto res = op.evaluate({&originalIndices, indices}, {numPartition});
|
auto res = op.evaluate({&originalIndices, indices}, {numPartition});
|
||||||
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||||
ops::dynamic_stitch stichOp;
|
ops::dynamic_stitch stichOp;
|
||||||
std::vector<NDArray*> partitions(numPartition * 2);
|
std::vector<NDArray*> partitions(numPartition * 2);
|
||||||
for (size_t i = 0; i < res->size(); i++) {
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
partitions[i] = res->at(i);
|
partitions[i] = res.at(i);
|
||||||
partitions[i + numPartition] = gradOutList[i];
|
partitions[i + numPartition] = gradOutList[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = stichOp.evaluate(partitions, {numPartition});
|
auto result = stichOp.evaluate(partitions, {numPartition});
|
||||||
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||||
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
|
result.at(0)->reshapei(outputList[0]->getShapeAsVector());
|
||||||
outputList[1]->assign(indices);
|
outputList[1]->assign(indices);
|
||||||
outputList[0]->assign(result->at(0));
|
outputList[0]->assign(result.at(0));
|
||||||
|
|
||||||
// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList);
|
// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList);
|
||||||
delete res;
|
|
||||||
delete result;
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,10 +66,10 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
sd::ops::gather op;
|
sd::ops::gather op;
|
||||||
|
|
||||||
std::unique_ptr<ResultSet> result(op.evaluate({input, indeces}, {0}));
|
auto result(op.evaluate({input, indeces}, {0}));
|
||||||
REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
|
REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op.");
|
||||||
REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
|
REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op.");
|
||||||
output->assign(result->at(0));
|
output->assign(result.at(0));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,8 +95,8 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
||||||
// forward steps
|
// forward steps
|
||||||
sd::ops::dynamic_rnn dynamicRnn;
|
sd::ops::dynamic_rnn dynamicRnn;
|
||||||
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
|
auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor});
|
||||||
hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
|
hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW]
|
||||||
hFWFinal->assign(resultsFW->at(1));
|
hFWFinal->assign(resultsFW.at(1));
|
||||||
|
|
||||||
auto seqLen = maxTimeStep;
|
auto seqLen = maxTimeStep;
|
||||||
if(seqLen == nullptr) {
|
if(seqLen == nullptr) {
|
||||||
|
@ -108,22 +108,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
||||||
// reverse x
|
// reverse x
|
||||||
sd::ops::reverse_sequence reverse;
|
sd::ops::reverse_sequence reverse;
|
||||||
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
|
auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0});
|
||||||
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
|
REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
|
||||||
auto revInput = resultsIn->at(0);
|
auto revInput = resultsIn.at(0);
|
||||||
|
|
||||||
// backward steps
|
// backward steps
|
||||||
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
|
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
|
||||||
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
|
auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW]
|
||||||
hBWFinal->assign(resultsBW->at(1));
|
hBWFinal->assign(resultsBW.at(1));
|
||||||
|
|
||||||
// reverse hBWtemp
|
// reverse hBWtemp
|
||||||
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
|
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
|
||||||
hBW->assign(resultsOut->at(0));
|
hBW->assign(resultsOut.at(0));
|
||||||
|
|
||||||
delete resultsOut;
|
|
||||||
delete resultsBW;
|
|
||||||
delete resultsIn;
|
|
||||||
delete resultsFW;
|
|
||||||
|
|
||||||
if(seqLen != maxTimeStep)
|
if(seqLen != maxTimeStep)
|
||||||
delete seqLen;
|
delete seqLen;
|
||||||
|
@ -228,12 +223,6 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,14 +52,13 @@ namespace helpers {
|
||||||
|
|
||||||
sd::ops::unique opUnique;
|
sd::ops::unique opUnique;
|
||||||
auto uResult = opUnique.evaluate({&arrayFull});
|
auto uResult = opUnique.evaluate({&arrayFull});
|
||||||
if (Status::OK() != uResult->status())
|
if (Status::OK() != uResult.status())
|
||||||
throw std::runtime_error("multiUnique: cannot execute unique op properly.");
|
throw std::runtime_error("multiUnique: cannot execute unique op properly.");
|
||||||
|
|
||||||
auto uniqueVals = uResult->at(0);
|
auto uniqueVals = uResult.at(0);
|
||||||
|
|
||||||
bool res = uniqueVals->lengthOf() == arrayFull.lengthOf();
|
bool res = uniqueVals->lengthOf() == arrayFull.lengthOf();
|
||||||
|
|
||||||
delete uResult;
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ namespace sd {
|
||||||
block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList);
|
block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs) {
|
||||||
std::vector<NDArray*> ins(inputs);
|
std::vector<NDArray*> ins(inputs);
|
||||||
std::vector<double> tas(tArgs);
|
std::vector<double> tas(tArgs);
|
||||||
std::vector<int> ias(iArgs);
|
std::vector<int> ias(iArgs);
|
||||||
|
@ -94,7 +94,7 @@ namespace sd {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultSet* DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
|
ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs) {
|
||||||
VariableSpace varSpace;
|
VariableSpace varSpace;
|
||||||
int nodeId = 119;
|
int nodeId = 119;
|
||||||
|
|
||||||
|
@ -132,8 +132,8 @@ namespace sd {
|
||||||
|
|
||||||
|
|
||||||
Nd4jStatus result = this->validateAndExecute(block);
|
Nd4jStatus result = this->validateAndExecute(block);
|
||||||
auto res = new ResultSet();
|
ResultSet res;
|
||||||
res->setStatus(result);
|
res.setStatus(result);
|
||||||
|
|
||||||
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
||||||
std::pair<int,int> pair(1, e);
|
std::pair<int,int> pair(1, e);
|
||||||
|
@ -143,10 +143,10 @@ namespace sd {
|
||||||
auto arr = var->getNDArray();
|
auto arr = var->getNDArray();
|
||||||
if (arr->isAttached()) {
|
if (arr->isAttached()) {
|
||||||
auto d = arr->detach();
|
auto d = arr->detach();
|
||||||
res->push_back(d);
|
res.push_back(d);
|
||||||
} else {
|
} else {
|
||||||
var->markRemovable(false);
|
var->markRemovable(false);
|
||||||
res->push_back(arr);
|
res.push_back(arr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
|
|
|
@ -962,12 +962,12 @@ namespace sd {
|
||||||
return execute(&ctx);
|
return execute(&ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs) {
|
||||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<int> iArgs) {
|
||||||
std::vector<Nd4jLong> realArgs;
|
std::vector<Nd4jLong> realArgs;
|
||||||
for (auto v:iArgs)
|
for (auto v:iArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
@ -976,12 +976,12 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||||
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>());
|
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<sd::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<float> tArgs) {
|
||||||
std::vector<double> realArgs;
|
std::vector<double> realArgs;
|
||||||
for (auto v:tArgs)
|
for (auto v:tArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
@ -990,21 +990,21 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
||||||
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<sd::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
||||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>());
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<sd::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<sd::DataType> bArgs) {
|
||||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
|
sd::ResultSet DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<sd::DataType> &dArgs, bool isInplace) {
|
||||||
VariableSpace variableSpace;
|
VariableSpace variableSpace;
|
||||||
//ResultSet arrayList;
|
//ResultSet arrayList;
|
||||||
FlowPath fp;
|
FlowPath fp;
|
||||||
|
@ -1041,11 +1041,11 @@ namespace sd {
|
||||||
block.getDArguments()->push_back(dArgs.at(e));
|
block.getDArguments()->push_back(dArgs.at(e));
|
||||||
|
|
||||||
Nd4jStatus status = this->execute(&block);
|
Nd4jStatus status = this->execute(&block);
|
||||||
auto arrayList = new ResultSet();
|
ResultSet arrayList;
|
||||||
if (isInplace)
|
if (isInplace)
|
||||||
arrayList->setNonRemovable();
|
arrayList.setNonRemovable();
|
||||||
|
|
||||||
arrayList->setStatus(status);
|
arrayList.setStatus(status);
|
||||||
if (status != ND4J_STATUS_OK)
|
if (status != ND4J_STATUS_OK)
|
||||||
return arrayList;
|
return arrayList;
|
||||||
|
|
||||||
|
@ -1058,23 +1058,23 @@ namespace sd {
|
||||||
if (!arr->isAttached()) {
|
if (!arr->isAttached()) {
|
||||||
var->markRemovable(false);
|
var->markRemovable(false);
|
||||||
arr->setContext(sd::LaunchContext::defaultContext());
|
arr->setContext(sd::LaunchContext::defaultContext());
|
||||||
arrayList->push_back(arr);
|
arrayList.push_back(arr);
|
||||||
} else {
|
} else {
|
||||||
arrayList->push_back(arr->detach());
|
arrayList.push_back(arr->detach());
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto v:inputs) {
|
for (auto v:inputs) {
|
||||||
arrayList->push_back(v);
|
arrayList.push_back(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return arrayList;
|
return arrayList;
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::ResultSet* sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
|
sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) {
|
||||||
// FIXME: add DArgs to OpArgsHolder
|
// FIXME: add DArgs to OpArgsHolder
|
||||||
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace);
|
return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector<sd::DataType>(), isInplace);
|
||||||
}
|
}
|
||||||
|
|
|
@ -357,20 +357,20 @@ namespace sd {
|
||||||
return DeclarableOp::execute(block);
|
return DeclarableOp::execute(block);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
|
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list<NDArray*> inputs, std::initializer_list<double> tArgs, std::initializer_list<int> iArgs, bool isInplace) {
|
||||||
std::vector<NDArray*> ins(inputs);
|
std::vector<NDArray*> ins(inputs);
|
||||||
std::vector<double> tas(tArgs);
|
std::vector<double> tas(tArgs);
|
||||||
std::vector<int> ias(iArgs);
|
std::vector<int> ias(iArgs);
|
||||||
return this->execute(rng, ins, tas, ias, isInplace);
|
return this->execute(rng, ins, tas, ias, isInplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::ResultSet* LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
|
sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector<NDArray*>& inputs, std::vector<double>& tArgs, std::vector<int>& iArgs, bool isInplace) {
|
||||||
VariableSpace variableSpace;
|
VariableSpace variableSpace;
|
||||||
auto arrayList = new ResultSet();
|
ResultSet arrayList;
|
||||||
//ResultSet arrayList;
|
//ResultSet arrayList;
|
||||||
|
|
||||||
if (isInplace)
|
if (isInplace)
|
||||||
arrayList->setNonRemovable();
|
arrayList.setNonRemovable();
|
||||||
|
|
||||||
int cnt = -1;
|
int cnt = -1;
|
||||||
std::vector<int> in;
|
std::vector<int> in;
|
||||||
|
@ -398,7 +398,7 @@ namespace sd {
|
||||||
block.getIArguments()->emplace_back(iArgs.at(e));
|
block.getIArguments()->emplace_back(iArgs.at(e));
|
||||||
|
|
||||||
Nd4jStatus status = this->execute(&block);
|
Nd4jStatus status = this->execute(&block);
|
||||||
arrayList->setStatus(status);
|
arrayList.setStatus(status);
|
||||||
if (status != ND4J_STATUS_OK)
|
if (status != ND4J_STATUS_OK)
|
||||||
return arrayList;
|
return arrayList;
|
||||||
|
|
||||||
|
@ -410,9 +410,9 @@ namespace sd {
|
||||||
auto arr = var->getNDArray();
|
auto arr = var->getNDArray();
|
||||||
if (!arr->isAttached()) {
|
if (!arr->isAttached()) {
|
||||||
var->markRemovable(false);
|
var->markRemovable(false);
|
||||||
arrayList->push_back(arr);
|
arrayList.push_back(arr);
|
||||||
} else {
|
} else {
|
||||||
arrayList->push_back(arr->detach());
|
arrayList.push_back(arr->detach());
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
break;
|
break;
|
||||||
|
@ -423,4 +423,4 @@ namespace sd {
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,9 +44,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) {
|
||||||
|
|
||||||
sd::ops::dot_product_attention op;
|
sd::ops::dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
|
auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -72,9 +70,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
|
||||||
|
|
||||||
sd::ops::dot_product_attention op;
|
sd::ops::dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
|
auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
||||||
|
@ -86,9 +82,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
||||||
|
|
||||||
sd::ops::dot_product_attention op;
|
sd::ops::dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -118,9 +112,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
|
||||||
|
|
||||||
sd::ops::dot_product_attention op;
|
sd::ops::dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -154,9 +146,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
|
||||||
|
|
||||||
sd::ops::multi_head_dot_product_attention op;
|
sd::ops::multi_head_dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
|
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -198,9 +188,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
|
||||||
|
|
||||||
sd::ops::multi_head_dot_product_attention op;
|
sd::ops::multi_head_dot_product_attention op;
|
||||||
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
|
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -39,13 +39,11 @@ TEST_F(BackpropTests, Test_Add_1) {
|
||||||
sd::ops::add_bp op;
|
sd::ops::add_bp op;
|
||||||
auto result = op.evaluate({&x, &y, &e});
|
auto result = op.evaluate({&x, &y, &e});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto eps = result->at(0);
|
auto eps = result.at(0);
|
||||||
auto grad = result->at(1);
|
auto grad = result.at(1);
|
||||||
|
|
||||||
ASSERT_TRUE(x.isSameShape(eps));
|
ASSERT_TRUE(x.isSameShape(eps));
|
||||||
ASSERT_TRUE(y.isSameShape(grad));
|
ASSERT_TRUE(y.isSameShape(grad));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
|
@ -137,13 +137,12 @@ TEST_F(BooleanOpsTests, test_where_1) {
|
||||||
sd::ops::choose op;
|
sd::ops::choose op;
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &y}, {3});
|
auto result = op.evaluate({&x, &y}, {3});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
//z->printIndexedBuffer("z");
|
//z->printIndexedBuffer("z");
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,9 +48,9 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
|
||||||
sd::ops::add op;
|
sd::ops::add op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
//exp.printIndexedBuffer("E A");
|
//exp.printIndexedBuffer("E A");
|
||||||
//z->printIndexedBuffer("Z");
|
//z->printIndexedBuffer("Z");
|
||||||
|
@ -58,7 +58,6 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,14 +74,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
|
||||||
sd::ops::multiply op;
|
sd::ops::multiply op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,14 +97,13 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
|
||||||
sd::ops::squaredsubtract op;
|
sd::ops::squaredsubtract op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,14 +115,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
|
||||||
sd::ops::subtract op;
|
sd::ops::subtract op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,14 +132,12 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
|
||||||
sd::ops::add op;
|
sd::ops::add op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,14 +148,12 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
|
||||||
|
|
||||||
sd::ops::maximum op;
|
sd::ops::maximum op;
|
||||||
auto result = op.evaluate({&x, &row});
|
auto result = op.evaluate({&x, &row});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,15 +164,14 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
|
||||||
|
|
||||||
sd::ops::minimum op;
|
sd::ops::minimum op;
|
||||||
auto result = op.evaluate({&x, &col});
|
auto result = op.evaluate({&x, &col});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -283,14 +272,13 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
|
||||||
|
|
||||||
sd::ops::add op;
|
sd::ops::add op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -333,11 +321,9 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
|
||||||
|
|
||||||
sd::ops::subtract op;
|
sd::ops::subtract op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
|
TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
|
||||||
|
@ -511,13 +497,12 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
|
||||||
|
|
||||||
sd::ops::multiply op;
|
sd::ops::multiply op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
||||||
|
@ -527,13 +512,11 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
||||||
|
|
||||||
sd::ops::multiply op;
|
sd::ops::multiply op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.equalsTo(z));
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -606,14 +589,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
||||||
sd::ops::maximum op;
|
sd::ops::maximum op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||||
|
@ -625,14 +606,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||||
sd::ops::maximum op;
|
sd::ops::maximum op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||||
|
@ -644,14 +624,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||||
sd::ops::realdiv op;
|
sd::ops::realdiv op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||||
|
@ -663,14 +642,13 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||||
sd::ops::realdiv op;
|
sd::ops::realdiv op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||||
|
@ -682,14 +660,12 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||||
sd::ops::realdiv op;
|
sd::ops::realdiv op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -718,15 +694,13 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
|
||||||
sd::ops::greater op;
|
sd::ops::greater op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z");
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_TRUE(e.equalsTo(*z));
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -48,9 +48,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
|
||||||
sd::ops::conv2d op;
|
sd::ops::conv2d op;
|
||||||
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());
|
ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
||||||
|
@ -63,13 +61,12 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
||||||
|
|
||||||
sd::ops::conv2d op;
|
sd::ops::conv2d op;
|
||||||
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -47,13 +47,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
|
||||||
|
|
||||||
sd::ops::scatter_upd op;
|
sd::ops::scatter_upd op;
|
||||||
auto result = op.evaluate({ &x, &y, &w });
|
auto result = op.evaluate({ &x, &y, &w });
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
||||||
|
@ -67,13 +65,11 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
||||||
|
|
||||||
sd::ops::scatter_upd op;
|
sd::ops::scatter_upd op;
|
||||||
auto result = op.evaluate({ &x, &indices, &updates });
|
auto result = op.evaluate({ &x, &indices, &updates });
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
|
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
|
||||||
|
@ -136,13 +132,11 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
|
||||||
|
|
||||||
sd::ops::bits_hamming_distance op;
|
sd::ops::bits_hamming_distance op;
|
||||||
auto result = op.evaluate({ &x, &y });
|
auto result = op.evaluate({ &x, &y });
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
|
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
|
||||||
|
@ -167,10 +161,8 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
|
||||||
|
|
||||||
sd::ops::cast op;
|
sd::ops::cast op;
|
||||||
auto result = op.evaluate({&x}, {10});
|
auto result = op.evaluate({&x}, {10});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_EQ(e, *result->at(0));
|
ASSERT_EQ(e, *result.at(0));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_range_1) {
|
TEST_F(DeclarableOpsTests16, test_range_1) {
|
||||||
|
@ -717,8 +709,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
||||||
|
|
||||||
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
|
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
|
||||||
|
@ -767,7 +757,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
||||||
|
|
||||||
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
|
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
|
||||||
|
@ -798,7 +787,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
||||||
|
|
||||||
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
|
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||||
|
@ -826,7 +814,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
ASSERT_TRUE(expected.equalsTo(actual));
|
ASSERT_TRUE(expected.equalsTo(actual));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -850,7 +837,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
ASSERT_TRUE(expected.equalsTo(actual));
|
ASSERT_TRUE(expected.equalsTo(actual));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
||||||
|
@ -891,8 +877,6 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
||||||
|
|
||||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
|
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
|
||||||
|
@ -937,8 +921,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
||||||
|
|
||||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
|
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
|
||||||
|
@ -983,7 +965,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
||||||
|
|
||||||
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
|
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
|
||||||
|
@ -1010,7 +991,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
||||||
|
|
||||||
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||||
|
@ -1037,8 +1017,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
||||||
|
|
||||||
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
|
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
|
||||||
|
@ -1061,7 +1039,6 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
||||||
#endif
|
#endif
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
ASSERT_TRUE(expected.equalsTo(actual));
|
ASSERT_TRUE(expected.equalsTo(actual));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
||||||
|
@ -1096,5 +1073,4 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
ASSERT_TRUE(expected.equalsTo(actual));
|
ASSERT_TRUE(expected.equalsTo(actual));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,9 +49,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
|
||||||
|
|
||||||
sd::ops::compat_sparse_to_dense op;
|
sd::ops::compat_sparse_to_dense op;
|
||||||
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||||
|
@ -64,9 +62,8 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||||
|
|
||||||
sd::ops::compat_sparse_to_dense op;
|
sd::ops::compat_sparse_to_dense op;
|
||||||
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||||
|
@ -78,11 +75,11 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||||
|
|
||||||
sd::ops::compat_string_split op;
|
sd::ops::compat_string_split op;
|
||||||
auto result = op.evaluate({&x, &delimiter});
|
auto result = op.evaluate({&x, &delimiter});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_EQ(2, result->size());
|
ASSERT_EQ(2, result.size());
|
||||||
|
|
||||||
auto z0 = result->at(0);
|
auto z0 = result.at(0);
|
||||||
auto z1 = result->at(1);
|
auto z1 = result.at(1);
|
||||||
|
|
||||||
ASSERT_TRUE(exp0.isSameShape(z0));
|
ASSERT_TRUE(exp0.isSameShape(z0));
|
||||||
ASSERT_TRUE(exp1.isSameShape(z1));
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||||
|
@ -90,5 +87,4 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||||
ASSERT_EQ(exp0, *z0);
|
ASSERT_EQ(exp0, *z0);
|
||||||
ASSERT_EQ(exp1, *z1);
|
ASSERT_EQ(exp1, *z1);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
|
@ -62,10 +62,8 @@ TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
||||||
|
|
||||||
sd::ops::conv1d_bp op;
|
sd::ops::conv1d_bp op;
|
||||||
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
|
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests19, test_squeeze_1) {
|
TEST_F(DeclarableOpsTests19, test_squeeze_1) {
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -51,14 +51,12 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
||||||
sd::ops::choose op;
|
sd::ops::choose op;
|
||||||
//greater than test
|
//greater than test
|
||||||
auto result = op.evaluate({&x}, {0.0},{3});
|
auto result = op.evaluate({&x}, {0.0},{3});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(1);
|
auto z = result.at(1);
|
||||||
|
|
||||||
ASSERT_EQ(148,z->e<double>(0));
|
ASSERT_EQ(148,z->e<double>(0));
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -67,9 +67,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({empty, vector}, {}, {0});
|
auto result = op.evaluate({empty, vector}, {}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z shape");
|
// z->printShapeInfo("z shape");
|
||||||
// z->printIndexedBuffer("z buffr");
|
// z->printIndexedBuffer("z buffr");
|
||||||
|
@ -78,7 +78,6 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
||||||
|
|
||||||
delete empty;
|
delete empty;
|
||||||
delete vector;
|
delete vector;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,9 +91,9 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
|
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z shape");
|
// z->printShapeInfo("z shape");
|
||||||
// z->printIndexedBuffer("z buffr");
|
// z->printIndexedBuffer("z buffr");
|
||||||
|
@ -104,7 +103,6 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
||||||
delete empty;
|
delete empty;
|
||||||
delete scalar1;
|
delete scalar1;
|
||||||
delete scalar2;
|
delete scalar2;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Concat_3) {
|
TEST_F(EmptyTests, Test_Concat_3) {
|
||||||
|
@ -117,13 +115,12 @@ TEST_F(EmptyTests, Test_Concat_3) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
|
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(exp, *z);
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Concat_4) {
|
TEST_F(EmptyTests, Test_Concat_4) {
|
||||||
|
@ -136,13 +133,11 @@ TEST_F(EmptyTests, Test_Concat_4) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
|
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(exp, *z);
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_1) {
|
TEST_F(EmptyTests, Test_Reshape_1) {
|
||||||
|
@ -153,12 +148,11 @@ TEST_F(EmptyTests, Test_Reshape_1) {
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result = op.evaluate({&vector, empty}, {}, {});
|
auto result = op.evaluate({&vector, empty}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
ASSERT_EQ(exp, *result->at(0));
|
ASSERT_EQ(exp, *result.at(0));
|
||||||
|
|
||||||
delete empty;
|
delete empty;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||||
|
@ -168,14 +162,13 @@ TEST_F(EmptyTests, Test_Reshape_3) {
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result = op.evaluate({&x, &y}, {}, {});
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_dup_1) {
|
TEST_F(EmptyTests, Test_dup_1) {
|
||||||
|
@ -197,12 +190,11 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
|
||||||
|
|
||||||
sd::ops::scatter_upd op;
|
sd::ops::scatter_upd op;
|
||||||
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
|
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(x, *z);
|
ASSERT_EQ(x, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, test_empty_scatter_2) {
|
TEST_F(EmptyTests, test_empty_scatter_2) {
|
||||||
|
@ -288,17 +280,15 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
|
auto result0 = op.evaluate({&x0, &shape0}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result0->status());
|
ASSERT_EQ(Status::OK(), result0.status());
|
||||||
auto z0 = result0->at(0);
|
auto z0 = result0.at(0);
|
||||||
ASSERT_EQ(e0, *z0);
|
ASSERT_EQ(e0, *z0);
|
||||||
|
|
||||||
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
|
auto result1 = op.evaluate({&x1, &shape1}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result1->status());
|
ASSERT_EQ(Status::OK(), result1.status());
|
||||||
auto z1 = result1->at(0);
|
auto z1 = result1.at(0);
|
||||||
ASSERT_EQ(e1, *z1);
|
ASSERT_EQ(e1, *z1);
|
||||||
|
|
||||||
delete result0;
|
|
||||||
delete result1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -309,12 +299,11 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
|
||||||
|
|
||||||
sd::ops::matmul op;
|
sd::ops::matmul op;
|
||||||
auto result = op.evaluate({&x, &y}, {}, {});
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, test_empty_matmul_2) {
|
TEST_F(EmptyTests, test_empty_matmul_2) {
|
||||||
|
@ -324,10 +313,8 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
|
||||||
|
|
||||||
sd::ops::matmul op;
|
sd::ops::matmul op;
|
||||||
auto result = op.evaluate({&x, &y}, {}, {});
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1889,20 +1889,19 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) {
|
||||||
OpArgsHolder holderFF({&input}, {}, {2, 3});
|
OpArgsHolder holderFF({&input}, {}, {2, 3});
|
||||||
sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly
|
sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly
|
||||||
auto results = opFF.execute(holderFF);
|
auto results = opFF.execute(holderFF);
|
||||||
auto tiled = results->at(0);
|
auto tiled = results.at(0);
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results.status());
|
||||||
ASSERT_TRUE(exp.isSameShape(tiled));
|
ASSERT_TRUE(exp.isSameShape(tiled));
|
||||||
ASSERT_TRUE(exp.equalsTo(tiled));
|
ASSERT_TRUE(exp.equalsTo(tiled));
|
||||||
delete results;
|
|
||||||
|
|
||||||
OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true);
|
OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true);
|
||||||
sd::ops::tile_bp opBP;
|
sd::ops::tile_bp opBP;
|
||||||
results = opBP.execute(holderBP);
|
results = opBP.execute(holderBP);
|
||||||
auto gradI = results->at(0);
|
auto gradI = results.at(0);
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results.status());
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
delete results;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,13 +47,13 @@ TEST_F(IndexingTests, StridedSlice_1) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
|
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,14 +66,14 @@ TEST_F(IndexingTests, StridedSlice_2) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
|
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,14 +86,14 @@ TEST_F(IndexingTests, StridedSlice_3) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
|
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,15 +109,15 @@ TEST_F(IndexingTests, SimpleSlice_1) {
|
||||||
sd::ops::slice op;
|
sd::ops::slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
|
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,15 +135,15 @@ TEST_F(IndexingTests, SimpleSlice_2) {
|
||||||
sd::ops::slice op;
|
sd::ops::slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
|
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, SimpleSlice_3) {
|
TEST_F(IndexingTests, SimpleSlice_3) {
|
||||||
|
@ -160,15 +160,15 @@ TEST_F(IndexingTests, SimpleSlice_3) {
|
||||||
sd::ops::slice op;
|
sd::ops::slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
|
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, SimpleSlice_4) {
|
TEST_F(IndexingTests, SimpleSlice_4) {
|
||||||
|
@ -180,14 +180,14 @@ TEST_F(IndexingTests, SimpleSlice_4) {
|
||||||
sd::ops::slice op;
|
sd::ops::slice op;
|
||||||
|
|
||||||
auto result = op.evaluate({&input, &start, &stop});
|
auto result = op.evaluate({&input, &start, &stop});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -204,16 +204,16 @@ TEST_F(IndexingTests, MaskedSlice_0) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z");
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -230,14 +230,14 @@ TEST_F(IndexingTests, MaskedSlice_00) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -254,16 +254,16 @@ TEST_F(IndexingTests, MaskedSlice_1) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z");
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, MaskedSlice_2) {
|
TEST_F(IndexingTests, MaskedSlice_2) {
|
||||||
|
@ -275,14 +275,14 @@ TEST_F(IndexingTests, MaskedSlice_2) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,14 +295,14 @@ TEST_F(IndexingTests, MaskedSlice_3) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -315,15 +315,15 @@ TEST_F(IndexingTests, MaskedSlice_4) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, Live_Slice_1) {
|
TEST_F(IndexingTests, Live_Slice_1) {
|
||||||
|
@ -338,16 +338,16 @@ TEST_F(IndexingTests, Live_Slice_1) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
|
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printShapeInfo("z shape");
|
// z->printShapeInfo("z shape");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -361,14 +361,14 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, Test_StridedSlice_2) {
|
TEST_F(IndexingTests, Test_StridedSlice_2) {
|
||||||
|
@ -381,16 +381,16 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
// z->printIndexedBuffer("Z");
|
// z->printIndexedBuffer("Z");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -404,14 +404,14 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
|
||||||
sd::ops::strided_slice op;
|
sd::ops::strided_slice op;
|
||||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -426,16 +426,16 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
|
||||||
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
//z->printIndexedBuffer("Z");
|
//z->printIndexedBuffer("Z");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(IndexingTests, Test_Subarray_Strided_1) {
|
TEST_F(IndexingTests, Test_Subarray_Strided_1) {
|
||||||
|
@ -458,13 +458,13 @@ TEST_F(IndexingTests, MaskedSlice_5) {
|
||||||
sd::ops::strided_slice<float> op;
|
sd::ops::strided_slice<float> op;
|
||||||
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
|
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
*/
|
*/
|
|
@ -64,13 +64,13 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
|
||||||
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
|
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, Reciprocal_1) {
|
TEST_F(LegacyOpsTests, Reciprocal_1) {
|
||||||
|
@ -121,12 +121,12 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
|
||||||
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
|
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
|
||||||
auto result = op.evaluate({&x, &y}, {}, {});
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
//z->printBuffer("Z");
|
//z->printBuffer("Z");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, Scalar_Test_1) {
|
TEST_F(LegacyOpsTests, Scalar_Test_1) {
|
||||||
|
@ -154,10 +154,10 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
|
||||||
sd::ops::LegacyScalarOp op(scalar::Add, y);
|
sd::ops::LegacyScalarOp op(scalar::Add, y);
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -169,14 +169,14 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
|
||||||
|
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printBuffer("ReduceTest1");
|
// z->printBuffer("ReduceTest1");
|
||||||
ASSERT_TRUE(z->isScalar());
|
ASSERT_TRUE(z->isScalar());
|
||||||
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,16 +188,16 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
|
||||||
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
|
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
|
||||||
auto result = op.evaluate({&x, &axis}, {}, {});
|
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
auto exp = x.reduceAlongDimension(reduce::Sum, {1});
|
auto exp = x.reduceAlongDimension(reduce::Sum, {1});
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -209,15 +209,15 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
|
||||||
|
|
||||||
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
auto exp = x.reduceAlongDimension(reduce::Sum,{1});
|
auto exp = x.reduceAlongDimension(reduce::Sum,{1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -229,16 +229,16 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
|
||||||
|
|
||||||
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||||
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
|
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
|
||||||
// indices.printShapeInfo("Indices shape");
|
// indices.printShapeInfo("Indices shape");
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
// z->printIndexedBuffer("Output reduce 4");
|
// z->printIndexedBuffer("Output reduce 4");
|
||||||
// exp.printIndexedBuffer("Expected reduce 4");
|
// exp.printIndexedBuffer("Expected reduce 4");
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, ReduceTests_5) {
|
TEST_F(LegacyOpsTests, ReduceTests_5) {
|
||||||
|
@ -249,14 +249,14 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
|
||||||
|
|
||||||
auto result = op.evaluate({&x});
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printBuffer("ReduceTest1");
|
// z->printBuffer("ReduceTest1");
|
||||||
ASSERT_TRUE(z->isScalar());
|
ASSERT_TRUE(z->isScalar());
|
||||||
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -268,16 +268,16 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &axis}, {}, {});
|
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
auto exp = x.reduceAlongDimension(reduce::Mean, {1});
|
auto exp = x.reduceAlongDimension(reduce::Mean, {1});
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -289,15 +289,15 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
|
||||||
|
|
||||||
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
auto exp = x.reduceAlongDimension(reduce::Mean,{1});
|
auto exp = x.reduceAlongDimension(reduce::Mean,{1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -309,17 +309,17 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
|
||||||
|
|
||||||
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||||
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
|
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
// z->printIndexedBuffer("Reduce8 output");
|
// z->printIndexedBuffer("Reduce8 output");
|
||||||
// z->printShapeInfo("Reduce8 shape");
|
// z->printShapeInfo("Reduce8 shape");
|
||||||
// exp.printShapeInfo("Reduce8 expected shape");
|
// exp.printShapeInfo("Reduce8 expected shape");
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -331,14 +331,14 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
|
||||||
|
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(z->isScalar());
|
ASSERT_TRUE(z->isScalar());
|
||||||
ASSERT_EQ(24, z->e<int>(0));
|
ASSERT_EQ(24, z->e<int>(0));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -351,9 +351,9 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
||||||
|
|
||||||
auto result = op.evaluate({&x, &indices}, {}, {});
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer("Hello indexreduce2");
|
// z->printIndexedBuffer("Hello indexreduce2");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
//ASSERT_EQ(4, z->e<int>(0));
|
//ASSERT_EQ(4, z->e<int>(0));
|
||||||
|
@ -362,7 +362,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
||||||
//ASSERT_EQ(4, z->e<int>(3));
|
//ASSERT_EQ(4, z->e<int>(3));
|
||||||
//ASSERT_EQ(4, z->e<int>(4));
|
//ASSERT_EQ(4, z->e<int>(4));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
|
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
|
||||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
|
||||||
|
|
||||||
auto result = op.execute(&list, {&x}, {}, {1});
|
auto result = op.execute(&list, {&x}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
ASSERT_EQ(1, list.elements());
|
ASSERT_EQ(1, list.elements());
|
||||||
|
|
||||||
|
@ -47,8 +47,8 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) {
|
||||||
|
|
||||||
ASSERT_EQ(2, list.elements());
|
ASSERT_EQ(2, list.elements());
|
||||||
|
|
||||||
delete result;
|
|
||||||
delete result2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
||||||
|
@ -66,15 +66,15 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
||||||
|
|
||||||
auto result = op.execute(&list, {}, {}, {1});
|
auto result = op.execute(&list, {}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printShapeInfo();
|
// z->printShapeInfo();
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||||
|
@ -93,10 +93,10 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||||
|
|
||||||
auto result = op.execute(&list, {&x}, {}, {0});
|
auto result = op.execute(&list, {&x}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
ASSERT_EQ(list.elements(), 10);
|
ASSERT_EQ(list.elements(), 10);
|
||||||
|
|
||||||
// auto z = result->at(0);
|
// auto z = result.at(0);
|
||||||
// z->printShapeInfo("The first of");
|
// z->printShapeInfo("The first of");
|
||||||
// ASSERT_TRUE(exp.isSameShape(z));
|
// ASSERT_TRUE(exp.isSameShape(z));
|
||||||
// ASSERT_TRUE(exp.equalsTo(z));
|
// ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -107,7 +107,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||||
delete row;
|
delete row;
|
||||||
}
|
}
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
|
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
|
||||||
|
@ -126,20 +126,20 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||||
//
|
//
|
||||||
// auto result = op.execute(nullptr, {&x}, {}, {0});
|
// auto result = op.execute(nullptr, {&x}, {}, {0});
|
||||||
//
|
//
|
||||||
// ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
// ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
// ASSERT_EQ(result->size(), 10);
|
// ASSERT_EQ(result->size(), 10);
|
||||||
//
|
//
|
||||||
// // auto z = result->at(0);
|
// // auto z = result.at(0);
|
||||||
//// z->printShapeInfo("The first of");
|
//// z->printShapeInfo("The first of");
|
||||||
//// ASSERT_TRUE(exp.isSameShape(z));
|
//// ASSERT_TRUE(exp.isSameShape(z));
|
||||||
//// ASSERT_TRUE(exp.equalsTo(z));
|
//// ASSERT_TRUE(exp.equalsTo(z));
|
||||||
// for (int e = 0; e < 10; e++) {
|
// for (int e = 0; e < 10; e++) {
|
||||||
// auto row = result->at(e);
|
// auto row = result.at(e);
|
||||||
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
|
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
|
||||||
// //list.write(e, row);
|
// //list.write(e, row);
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// delete result;
|
//
|
||||||
// delete tads;
|
// delete tads;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
@ -160,14 +160,14 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) {
|
||||||
|
|
||||||
auto result = op.execute(&list, {}, {}, {4});
|
auto result = op.execute(&list, {}, {}, {4});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
||||||
|
@ -192,14 +192,14 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
||||||
sd::ops::pick_list op;
|
sd::ops::pick_list op;
|
||||||
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
|
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
||||||
|
@ -217,14 +217,14 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
||||||
|
|
||||||
auto result = op.execute(&list, {}, {}, {1});
|
auto result = op.execute(&list, {}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
||||||
|
@ -235,12 +235,12 @@ TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
||||||
|
|
||||||
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
|
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
// we return flow as well
|
// we return flow as well
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
||||||
|
@ -283,7 +283,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
||||||
|
|
||||||
sd::ops::split_list op;
|
sd::ops::split_list op;
|
||||||
auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
|
auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
ASSERT_EQ(3, list.height());
|
ASSERT_EQ(3, list.height());
|
||||||
|
|
||||||
|
@ -296,7 +296,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
|
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
|
||||||
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
|
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
||||||
|
@ -319,7 +319,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
||||||
sd::ops::scatter_list op;
|
sd::ops::scatter_list op;
|
||||||
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
|
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
for (int e = 0; e < 10; e++) {
|
for (int e = 0; e < 10; e++) {
|
||||||
auto row = tads.at(9 - e);
|
auto row = tads.at(9 - e);
|
||||||
|
@ -329,7 +329,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
||||||
|
|
||||||
ASSERT_TRUE(chunk->equalsTo(row));
|
ASSERT_TRUE(chunk->equalsTo(row));
|
||||||
}
|
}
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
|
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
|
||||||
|
@ -385,10 +385,10 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
|
||||||
sd::ops::gather_list op;
|
sd::ops::gather_list op;
|
||||||
auto result = op.execute(&list, {&indices}, {}, {});
|
auto result = op.execute(&list, {&indices}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {
|
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {
|
||||||
|
|
|
@ -134,13 +134,11 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
|
||||||
|
|
||||||
sd::ops::add op;
|
sd::ops::add op;
|
||||||
auto result = op.evaluate({&x, &y});
|
auto result = op.evaluate({&x, &y});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -66,7 +66,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row0 = syn0({0,1, 0,0}, true);
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
auto row1 = syn1({1,2, 0,0}, true);
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
@ -74,7 +74,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
|
||||||
ASSERT_EQ(exp0, row0);
|
ASSERT_EQ(exp0, row0);
|
||||||
ASSERT_EQ(exp1, row1);
|
ASSERT_EQ(exp1, row1);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_sg_hs_test_2) {
|
TEST_F(NlpTests, basic_sg_hs_test_2) {
|
||||||
|
@ -107,7 +107,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row0 = syn0({0,1, 0,0}, true);
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
auto row1 = syn1({1,2, 0,0}, true);
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
@ -117,7 +117,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
|
||||||
ASSERT_EQ(exp1, row1);
|
ASSERT_EQ(exp1, row1);
|
||||||
ASSERT_EQ(exp2, row2);
|
ASSERT_EQ(exp2, row2);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_sg_hs_test_3) {
|
TEST_F(NlpTests, basic_sg_hs_test_3) {
|
||||||
|
@ -159,7 +159,7 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result0->status());
|
ASSERT_EQ(Status::OK(), result0.status());
|
||||||
|
|
||||||
auto row00 = syn00({0,1, 0,0}, true);
|
auto row00 = syn00({0,1, 0,0}, true);
|
||||||
auto row01 = syn01({0,1, 0,0}, true);
|
auto row01 = syn01({0,1, 0,0}, true);
|
||||||
|
@ -168,9 +168,6 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
|
||||||
|
|
||||||
ASSERT_EQ(row2, row1);
|
ASSERT_EQ(row2, row1);
|
||||||
ASSERT_EQ(row00, row01);
|
ASSERT_EQ(row00, row01);
|
||||||
|
|
||||||
delete result0;
|
|
||||||
delete result1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
||||||
|
@ -192,9 +189,9 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_sg_ns_test_1) {
|
TEST_F(NlpTests, basic_sg_ns_test_1) {
|
||||||
|
@ -227,14 +224,14 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row0 = syn0({1,2, 0,0}, true);
|
auto row0 = syn0({1,2, 0,0}, true);
|
||||||
|
|
||||||
ASSERT_EQ(exp0, row0);
|
ASSERT_EQ(exp0, row0);
|
||||||
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
|
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_cb_hs_test_1) {
|
TEST_F(NlpTests, basic_cb_hs_test_1) {
|
||||||
|
@ -269,7 +266,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
|
||||||
|
|
||||||
sd::ops::cbow op;
|
sd::ops::cbow op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||||
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||||
|
@ -287,7 +284,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
|
||||||
ASSERT_EQ(exp1, row_s1_5);
|
ASSERT_EQ(exp1, row_s1_5);
|
||||||
ASSERT_EQ(exp2, row_s1_6);
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, basic_cb_ns_test_1) {
|
TEST_F(NlpTests, basic_cb_ns_test_1) {
|
||||||
|
@ -323,7 +320,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
|
||||||
|
|
||||||
sd::ops::cbow op;
|
sd::ops::cbow op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||||
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||||
|
@ -339,7 +336,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
|
||||||
ASSERT_EQ(exp0, row_s0_2);
|
ASSERT_EQ(exp0, row_s0_2);
|
||||||
ASSERT_EQ(exp2, row_s1_6);
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, test_sg_hs_batch_1) {
|
TEST_F(NlpTests, test_sg_hs_batch_1) {
|
||||||
|
@ -372,7 +369,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto row0 = syn0({0,1, 0,0}, true);
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
auto row1 = syn1({1,2, 0,0}, true);
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
@ -382,7 +379,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
|
||||||
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
|
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
|
||||||
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
|
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, test_sg_ns_batch_1) {
|
TEST_F(NlpTests, test_sg_ns_batch_1) {
|
||||||
|
@ -416,9 +413,9 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
|
||||||
|
|
||||||
sd::ops::skipgram op;
|
sd::ops::skipgram op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||||
|
@ -449,7 +446,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||||
|
|
||||||
sd::ops::cbow op;
|
sd::ops::cbow op;
|
||||||
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
@ -473,6 +470,5 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||||
ASSERT_EQ(exp1, row_s1_4);
|
ASSERT_EQ(exp1, row_s1_4);
|
||||||
ASSERT_EQ(exp1, row_s1_5);
|
ASSERT_EQ(exp1, row_s1_5);
|
||||||
ASSERT_EQ(exp2, row_s1_6);
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
delete result;
|
}
|
||||||
}
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -260,9 +260,9 @@ TEST_F(RNGTests, Test_Gaussian_21) {
|
||||||
auto result = op.evaluate({&x0}, {}, {});
|
auto result = op.evaluate({&x0}, {}, {});
|
||||||
//x0.printIndexedBuffer("X0 Normal");
|
//x0.printIndexedBuffer("X0 Normal");
|
||||||
//x1.printIndexedBuffer("X1 Normal");
|
//x1.printIndexedBuffer("X1 Normal");
|
||||||
ASSERT_TRUE(result->status() == Status::OK());
|
ASSERT_TRUE(result.status() == Status::OK());
|
||||||
auto mean = result->at(0);
|
auto mean = result.at(0);
|
||||||
auto variance = result->at(1);
|
auto variance = result.at(1);
|
||||||
|
|
||||||
// mean->printIndexedBuffer("Mean");
|
// mean->printIndexedBuffer("Mean");
|
||||||
// variance->printIndexedBuffer("Variance");
|
// variance->printIndexedBuffer("Variance");
|
||||||
|
@ -270,7 +270,7 @@ TEST_F(RNGTests, Test_Gaussian_21) {
|
||||||
ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f);
|
ASSERT_NEAR(sd::math::nd4j_abs(mean->e<float>(0)), 0.f, 0.2f);
|
||||||
ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f);
|
ASSERT_NEAR(variance->e<float>(0), 1.0f, 0.2f);
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef DEBUG_BUILD
|
#ifdef DEBUG_BUILD
|
||||||
|
@ -292,15 +292,15 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
||||||
auto result = op.evaluate({&x0}, {}, {});
|
auto result = op.evaluate({&x0}, {}, {});
|
||||||
//x0.printIndexedBuffer("X0 Normal");
|
//x0.printIndexedBuffer("X0 Normal");
|
||||||
//x1.printIndexedBuffer("X1 Normal");
|
//x1.printIndexedBuffer("X1 Normal");
|
||||||
ASSERT_TRUE(result->status() == Status::OK());
|
ASSERT_TRUE(result.status() == Status::OK());
|
||||||
auto mean0 = result->at(0);
|
auto mean0 = result.at(0);
|
||||||
auto variance0 = result->at(1);
|
auto variance0 = result.at(1);
|
||||||
|
|
||||||
//mean0->printIndexedBuffer("Mean");
|
//mean0->printIndexedBuffer("Mean");
|
||||||
//variance0->printIndexedBuffer("Variance");
|
//variance0->printIndexedBuffer("Variance");
|
||||||
ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f);
|
ASSERT_NEAR(sd::math::nd4j_abs(mean0->e<float>(0)), 0.f, 1.0e-3f);
|
||||||
ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f);
|
ASSERT_NEAR(variance0->e<float>(0), 1.0f, 1.e-3f);
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Gaussian_3) {
|
TEST_F(RNGTests, Test_Gaussian_3) {
|
||||||
|
@ -413,18 +413,17 @@ TEST_F(RNGTests, Test_Truncated_21) {
|
||||||
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
|
ASSERT_NEAR(deviation.e<float>(0), 2.f, 0.5);
|
||||||
sd::ops::moments op;
|
sd::ops::moments op;
|
||||||
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
||||||
// result->at(0)->printBuffer("MEAN");
|
|
||||||
// result->at(1)->printBuffer("VARIANCE");
|
// result.at(0)->printBuffer("MEAN");
|
||||||
delete result;
|
// result.at(1)->printBuffer("VARIANCE");
|
||||||
|
|
||||||
sd::ops::reduce_min minOp;
|
sd::ops::reduce_min minOp;
|
||||||
sd::ops::reduce_max maxOp;
|
sd::ops::reduce_max maxOp;
|
||||||
|
|
||||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||||
// minRes->at(0)->printBuffer("MIN for Truncated");
|
// minRes->at(0)->printBuffer("MIN for Truncated");
|
||||||
// maxRes->at(0)->printBuffer("MAX for Truncated");
|
// maxRes->at(0)->printBuffer("MAX for Truncated");
|
||||||
|
|
||||||
delete minRes;
|
|
||||||
delete maxRes;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Truncated_22) {
|
TEST_F(RNGTests, Test_Truncated_22) {
|
||||||
|
@ -460,18 +459,15 @@ TEST_F(RNGTests, Test_Truncated_22) {
|
||||||
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
|
ASSERT_NEAR(deviation.e<float>(0), 4.f, 0.52);
|
||||||
sd::ops::moments op;
|
sd::ops::moments op;
|
||||||
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
|
||||||
// result->at(0)->printBuffer("MEAN");
|
|
||||||
// result->at(1)->printBuffer("VARIANCE");
|
|
||||||
delete result;
|
|
||||||
sd::ops::reduce_min minOp;
|
sd::ops::reduce_min minOp;
|
||||||
sd::ops::reduce_max maxOp;
|
sd::ops::reduce_max maxOp;
|
||||||
|
|
||||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||||
// minRes->at(0)->printBuffer("MIN for Truncated2");
|
// minRes->at(0)->printBuffer("MIN for Truncated2");
|
||||||
// maxRes->at(0)->printBuffer("MAX for Truncated2");
|
// maxRes->at(0)->printBuffer("MAX for Truncated2");
|
||||||
|
|
||||||
delete minRes;
|
|
||||||
delete maxRes;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Truncated_23) {
|
TEST_F(RNGTests, Test_Truncated_23) {
|
||||||
|
@ -509,16 +505,14 @@ TEST_F(RNGTests, Test_Truncated_23) {
|
||||||
auto result = op.evaluate({&x0});
|
auto result = op.evaluate({&x0});
|
||||||
// result->at(0)->printBuffer("MEAN");
|
// result->at(0)->printBuffer("MEAN");
|
||||||
// result->at(1)->printBuffer("VARIANCE");
|
// result->at(1)->printBuffer("VARIANCE");
|
||||||
delete result;
|
|
||||||
sd::ops::reduce_min minOp;
|
sd::ops::reduce_min minOp;
|
||||||
sd::ops::reduce_max maxOp;
|
sd::ops::reduce_max maxOp;
|
||||||
|
|
||||||
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
auto minRes = minOp.evaluate({&x1}, {}, {}, {});
|
||||||
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
auto maxRes = maxOp.evaluate({&x0}, {}, {}, {});
|
||||||
// minRes->at(0)->printBuffer("MIN for Truncated3");
|
// minRes->at(0)->printBuffer("MIN for Truncated3");
|
||||||
// maxRes->at(0)->printBuffer("MAX for Truncated3");
|
// maxRes->at(0)->printBuffer("MAX for Truncated3");
|
||||||
|
|
||||||
delete minRes;
|
|
||||||
delete maxRes;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Truncated_3) {
|
TEST_F(RNGTests, Test_Truncated_3) {
|
||||||
|
@ -568,15 +562,15 @@ TEST_F(RNGTests, Test_Uniform_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(0);
|
auto op = new sd::ops::LegacyRandomOp(0);
|
||||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
|
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Gaussian_2) {
|
TEST_F(RNGTests, Test_Gaussian_2) {
|
||||||
|
@ -588,15 +582,15 @@ TEST_F(RNGTests, Test_Gaussian_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution);
|
auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution);
|
||||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
|
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_LogNorm_2) {
|
TEST_F(RNGTests, Test_LogNorm_2) {
|
||||||
|
@ -608,15 +602,15 @@ TEST_F(RNGTests, Test_LogNorm_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution);
|
auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution);
|
||||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
|
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_TruncatedNorm_2) {
|
TEST_F(RNGTests, Test_TruncatedNorm_2) {
|
||||||
|
@ -628,14 +622,14 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution);
|
auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution);
|
||||||
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -648,15 +642,15 @@ TEST_F(RNGTests, Test_Binomial_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx);
|
auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx);
|
||||||
auto result = op->execute(_rngA, {&input}, {0.5f}, {3});
|
auto result = op->execute(_rngA, {&input}, {0.5f}, {3});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
|
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -669,15 +663,15 @@ TEST_F(RNGTests, Test_Bernoulli_2) {
|
||||||
auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution);
|
auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution);
|
||||||
auto result = op->execute(_rngA, {&input}, {0.5f}, {});
|
auto result = op->execute(_rngA, {&input}, {0.5f}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(x1.isSameShape(z));
|
ASSERT_TRUE(x1.isSameShape(z));
|
||||||
ASSERT_TRUE(x1.equalsTo(z));
|
ASSERT_TRUE(x1.equalsTo(z));
|
||||||
|
|
||||||
delete op;
|
delete op;
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
||||||
|
@ -687,9 +681,9 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
||||||
|
|
||||||
sd::ops::random_normal op;
|
sd::ops::random_normal op;
|
||||||
auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
|
auto result = op.evaluate({&x}, {0.0, 1.0f}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
@ -698,7 +692,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) {
|
||||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
||||||
|
@ -708,9 +702,9 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
||||||
|
|
||||||
sd::ops::random_bernoulli op;
|
sd::ops::random_bernoulli op;
|
||||||
auto result = op.evaluate({&x}, {0.5f}, {});
|
auto result = op.evaluate({&x}, {0.5f}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
@ -718,7 +712,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) {
|
||||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -729,9 +723,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
|
||||||
|
|
||||||
sd::ops::random_exponential op;
|
sd::ops::random_exponential op;
|
||||||
auto result = op.evaluate({&x}, {0.25f}, {0});
|
auto result = op.evaluate({&x}, {0.25f}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
@ -740,7 +734,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) {
|
||||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||||
|
@ -753,9 +747,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||||
|
|
||||||
sd::ops::random_exponential op;
|
sd::ops::random_exponential op;
|
||||||
auto result = op.evaluate({&x, &y}, {0.25f}, {0});
|
auto result = op.evaluate({&x, &y}, {0.25f}, {0});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
@ -764,7 +758,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||||
ASSERT_FALSE(nexp1->equalsTo(z));
|
ASSERT_FALSE(nexp1->equalsTo(z));
|
||||||
ASSERT_FALSE(nexp2->equalsTo(z));
|
ASSERT_FALSE(nexp2->equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
||||||
|
@ -777,14 +771,14 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
||||||
|
|
||||||
sd::ops::random_poisson op;
|
sd::ops::random_poisson op;
|
||||||
auto result = op.evaluate({&x, &la}, {}, {});
|
auto result = op.evaluate({&x, &la}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer("Poisson distribution");
|
// z->printIndexedBuffer("Poisson distribution");
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_GammaDistribution_1) {
|
TEST_F(RNGTests, Test_GammaDistribution_1) {
|
||||||
|
@ -797,14 +791,14 @@ TEST_F(RNGTests, Test_GammaDistribution_1) {
|
||||||
|
|
||||||
sd::ops::random_gamma op;
|
sd::ops::random_gamma op;
|
||||||
auto result = op.evaluate({&x, &al}, {}, {});
|
auto result = op.evaluate({&x, &al}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer("Gamma distribution");
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_GammaDistribution_2) {
|
TEST_F(RNGTests, Test_GammaDistribution_2) {
|
||||||
|
@ -818,14 +812,14 @@ TEST_F(RNGTests, Test_GammaDistribution_2) {
|
||||||
|
|
||||||
sd::ops::random_gamma op;
|
sd::ops::random_gamma op;
|
||||||
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer("Gamma distribution");
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_GammaDistribution_3) {
|
TEST_F(RNGTests, Test_GammaDistribution_3) {
|
||||||
|
@ -839,14 +833,14 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
|
||||||
|
|
||||||
sd::ops::random_gamma op;
|
sd::ops::random_gamma op;
|
||||||
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
auto result = op.evaluate({&x, &al, &be}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer("Gamma distribution");
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_UniformDistribution_04) {
|
TEST_F(RNGTests, Test_UniformDistribution_04) {
|
||||||
|
@ -858,13 +852,13 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {
|
||||||
|
|
||||||
sd::ops::randomuniform op;
|
sd::ops::randomuniform op;
|
||||||
auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
|
auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(exp0.isSameShape(z));
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
ASSERT_FALSE(exp0.equalsTo(z));
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
|
@ -1021,12 +1015,12 @@ TEST_F(RNGTests, test_multinomial_1) {
|
||||||
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64);
|
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64);
|
||||||
|
|
||||||
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
|
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
|
||||||
auto outputZ = result->at(0);
|
auto outputZ = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
|
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
|
||||||
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
|
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, test_multinomial_2) {
|
TEST_F(RNGTests, test_multinomial_2) {
|
||||||
|
@ -1117,8 +1111,8 @@ TEST_F(RNGTests, test_multinomial_5) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
|
auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
|
||||||
auto outputR = resultR->at(0);
|
auto outputR = resultR.at(0);
|
||||||
ASSERT_EQ(Status::OK(), resultR->status());
|
ASSERT_EQ(Status::OK(), resultR.status());
|
||||||
|
|
||||||
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||||
mean = outputR->meanNumber();
|
mean = outputR->meanNumber();
|
||||||
|
@ -1131,7 +1125,6 @@ TEST_F(RNGTests, test_multinomial_5) {
|
||||||
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete resultR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1150,8 +1143,8 @@ TEST_F(RNGTests, test_multinomial_6) {
|
||||||
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
|
auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
|
||||||
auto outputR = resultR->at(0);
|
auto outputR = resultR.at(0);
|
||||||
ASSERT_EQ(Status::OK(), resultR->status());
|
ASSERT_EQ(Status::OK(), resultR.status());
|
||||||
|
|
||||||
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE);
|
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
@ -1175,7 +1168,7 @@ TEST_F(RNGTests, test_multinomial_6) {
|
||||||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
|
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||||
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
|
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||||
|
|
||||||
delete resultR;
|
|
||||||
|
|
||||||
RandomGenerator rng(1234, 1234);
|
RandomGenerator rng(1234, 1234);
|
||||||
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32);
|
||||||
|
|
|
@ -96,14 +96,14 @@ TEST_F(ScalarTests, Test_Concat_1) {
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,15 +116,15 @@ TEST_F(ScalarTests, Test_Concat_2) {
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
// z->printIndexedBuffer();
|
// z->printIndexedBuffer();
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,16 +137,16 @@ TEST_F(ScalarTests, Test_Concat_3) {
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
auto result = op.evaluate({&t, &u, &v}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
//z->printShapeInfo("z");
|
//z->printShapeInfo("z");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_ExpandDims_1) {
|
TEST_F(ScalarTests, Test_ExpandDims_1) {
|
||||||
|
@ -156,14 +156,14 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
|
||||||
sd::ops::expand_dims op;
|
sd::ops::expand_dims op;
|
||||||
auto result = op.evaluate({&x}, {}, {0});
|
auto result = op.evaluate({&x}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Squeeze_1) {
|
TEST_F(ScalarTests, Test_Squeeze_1) {
|
||||||
|
@ -172,14 +172,14 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
|
||||||
|
|
||||||
sd::ops::squeeze op;
|
sd::ops::squeeze op;
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,14 +189,14 @@ TEST_F(ScalarTests, Test_Reshape_1) {
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
|
auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -206,14 +206,14 @@ TEST_F(ScalarTests, Test_Permute_1) {
|
||||||
|
|
||||||
sd::ops::permute op;
|
sd::ops::permute op;
|
||||||
auto result = op.evaluate({&x}, {}, {0});
|
auto result = op.evaluate({&x}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
||||||
|
@ -225,14 +225,13 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
|
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -245,12 +244,11 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
|
||||||
|
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
|
auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
|
@ -308,14 +308,13 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
|
||||||
|
|
||||||
sd::ops::transpose op;
|
sd::ops::transpose op;
|
||||||
auto result = op.evaluate({&x});
|
auto result = op.evaluate({&x});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeTests, Tests_Transpose_119_3) {
|
TEST_F(ShapeTests, Tests_Transpose_119_3) {
|
||||||
|
|
|
@ -70,14 +70,14 @@ TEST_F(SingleDimTests, Test_Concat_1) {
|
||||||
sd::ops::concat op;
|
sd::ops::concat op;
|
||||||
auto result = op.evaluate({&x, &y}, {}, {0});
|
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reduce_1) {
|
TEST_F(SingleDimTests, Test_Reduce_1) {
|
||||||
|
@ -104,14 +104,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
|
||||||
sd::ops::expand_dims op;
|
sd::ops::expand_dims op;
|
||||||
auto result = op.evaluate({&x}, {}, {0});
|
auto result = op.evaluate({&x}, {}, {0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,14 +122,14 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
|
||||||
sd::ops::expand_dims op;
|
sd::ops::expand_dims op;
|
||||||
auto result = op.evaluate({&x}, {}, {1});
|
auto result = op.evaluate({&x}, {}, {1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,14 +142,14 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
|
||||||
sd::ops::squeeze op;
|
sd::ops::squeeze op;
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_EQ(exp.rankOf(), z->rankOf());
|
ASSERT_EQ(exp.rankOf(), z->rankOf());
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Squeeze_2) {
|
TEST_F(SingleDimTests, Test_Squeeze_2) {
|
||||||
|
@ -158,14 +158,14 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
|
||||||
|
|
||||||
sd::ops::squeeze op;
|
sd::ops::squeeze op;
|
||||||
auto result = op.evaluate({&x}, {}, {});
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reshape_1) {
|
TEST_F(SingleDimTests, Test_Reshape_1) {
|
||||||
|
@ -174,14 +174,14 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 3});
|
auto result = op.evaluate({&x}, {}, {-99, 3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SingleDimTests, Test_Reshape_2) {
|
TEST_F(SingleDimTests, Test_Reshape_2) {
|
||||||
|
@ -190,14 +190,14 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
|
auto result = op.evaluate({&x}, {}, {-99, 1, 3});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -207,12 +207,12 @@ TEST_F(SingleDimTests, Test_Permute_1) {
|
||||||
|
|
||||||
sd::ops::permute op;
|
sd::ops::permute op;
|
||||||
auto result = op.evaluate({&x}, {}, {0});
|
auto result = op.evaluate({&x}, {}, {0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue