Shugeo cuda tests (#116)
* Added tests for get_seed/set_seed ops.
* Added missing tests for scatter_sub/mul/div ops.
* Added tests for hardsigmoid and hardsigmoid_bp ops.
* Added tests for hardtanh and hardtanh_bp ops.
* Added test for histogram op.
* Added tests for identity op.
* Refactored mergemaxindex op. Added tests for log1p, mergemaxindex, mod and mod_bp ops.
* Fixed tests for FloorDiv.
* Added test for rank op.
* Added tests for rationaltanh/rationaltanh_bp ops.
* Added tests for realdiv/realdiv_bp ops.
* Added tests for rectifiedtanh/rectifiedtanh_bp ops.
* Added tests for shapes_of op.
* Added tests for size op.
* Added tests for softplus/softplus_bp ops.
* Added tests for softsign/softsign_bp ops.
* Added tests for toggle_bits op. Fixed processing of OP_IMPL (and similar) definitions.
* Added test for truncatediv op.
* Added another test for truncatediv op.
* Added another test for histogram op.
* Added tests for unstack_list op.
* Refactored to_int32/uint32/float16/float32/double/int64/uint64 ops and tests.
* Refactored mergemaxindex op helper for the cuda platform and its tests.
* Fixed cuda kernel for histogram op helper.
* Refactored skipgram to avoid an early buffer shift.
* Fixed check in non_max_suppression op cuda helper. Added cuda kernel implementation for skipgram op helpers.
* Added implementation of skipgram op helper for the cuda platform (working revision).
* Fixed mergeMaxIndex kernel and moved it to a separate source file.
parent 6264530dd8
commit f083b96c74
@@ -172,7 +172,9 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector<int8
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<float16>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<bfloat16>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<Nd4jLong>& data, nd4j::LaunchContext * context);
+template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<uint64_t>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int>& data, nd4j::LaunchContext * context);
+template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<unsigned int>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int16_t>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int8_t>& data, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<uint8_t>& data, nd4j::LaunchContext * context);
@@ -137,8 +137,8 @@ namespace nd4j {
     auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args);
     auto result = array->allTensorsAlongDimension(newAxis);
     for (int e = 0; e < result->size(); e++) {
-        auto chunk = result->at(e)->dup(array->ordering());
-        write(e, chunk);
+        auto chunk = result->at(e);//->dup(array->ordering());
+        write(e, chunk->dup(array->ordering()));
     }
     delete result;
 }
@@ -1328,7 +1328,7 @@
 REGISTER_C(NAME) \
 nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
     auto shapeList = SHAPELIST(); \
-    for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
+    for (int e = 0; e < block.width(); e++) { \
         auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
         shapeList->push_back(newshape); \
     } \
@@ -1365,7 +1365,7 @@
 REGISTER_C(NAME) \
 nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
     auto shapeList = SHAPELIST(); \
-    for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
+    for (int e = 0; e < block.width(); e++) { \
         Nd4jLong* newshape; \
         COPY_SHAPE(inputShape->at(0), newshape); \
         shapeList->push_back(CONSTANT(newshape)); \
@@ -1388,7 +1388,7 @@
 REGISTER_C(NAME) \
 nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
     auto shapeList = SHAPELIST(); \
-    for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
+    for (int e = 0; e < block.width(); e++) { \
         auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
         shapeList->push_back(newshape); \
     } \
@@ -27,14 +27,14 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(toggle_bits, -1, -1, true) {
+    OP_IMPL(toggle_bits, -1, 1, true) {

         for (int i = 0; i < block.width(); i++) {
             auto x = INPUT_VARIABLE(i);
             auto z = OUTPUT_VARIABLE(i);

             REQUIRE_TRUE(x->dataType() == z->dataType(), 0, "Toggle bits requires input and output to have same type");
-            REQUIRE_TRUE(x->isR(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)");
+            REQUIRE_TRUE(x->isZ(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)");

             helpers::__toggle_bits(block.launchContext(), *x, *z);
         }
@@ -44,7 +44,8 @@ namespace nd4j {
     DECLARE_TYPES(toggle_bits) {
         getOpDescriptor()
             ->setAllowedInputTypes({ALL_INTS})
-            ->setSameMode(true);
+            ->setAllowedOutputTypes({ALL_INTS})
+            ->setSameMode(false);
     }
 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_double, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,12 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::DOUBLE);
     }

+    DECLARE_SHAPE_FN(to_double) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::DOUBLE, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_float16, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,12 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::HALF);
     }

+    DECLARE_SHAPE_FN(to_float16) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::HALF, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_float32, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,12 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::FLOAT32);
     }

+    DECLARE_SHAPE_FN(to_float32) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::FLOAT32, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_int32, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,11 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::INT32);
     }
+    DECLARE_SHAPE_FN(to_int32) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT32, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_int64, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,11 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::INT64);
     }
+    DECLARE_SHAPE_FN(to_int64) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT64, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_uint32, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -40,8 +40,13 @@ namespace nd4j {
     DECLARE_TYPES(to_uint32) {
         getOpDescriptor()
             ->setAllowedInputTypes(nd4j::DataType::ANY)
-            ->setAllowedOutputTypes(nd4j::DataType::INT16);
+            ->setAllowedOutputTypes(nd4j::DataType::INT32);
     }
+    DECLARE_SHAPE_FN(to_uint32) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT32, true, block.workspace());
+        return SHAPELIST(outShape);
+    }

 }
 }
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(to_uint64, 1, 1, true) {
+    CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto output = OUTPUT_VARIABLE(0);
@@ -42,6 +42,10 @@ namespace nd4j {
             ->setAllowedInputTypes(nd4j::DataType::ANY)
             ->setAllowedOutputTypes(nd4j::DataType::INT8);
     }
+    DECLARE_SHAPE_FN(to_uint64) {
+        auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT64, true, block.workspace());
+        return SHAPELIST(outShape);
+    }
 }
 }
@@ -26,13 +26,19 @@
 namespace nd4j {
 namespace ops {
     LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) {
-        auto input = INPUT_VARIABLE(0);
+        auto outputList = INPUT_LIST(0);
+        auto input = INPUT_VARIABLE(int(outputList != nullptr) );

-        auto list = new NDArrayList(0, true);
-        list->unstack(input, 0);
+        if (outputList == nullptr) {
+            outputList = new NDArrayList(0, true);
+            //block.trackList(outputList);
+            setupResultList(outputList, block);
+        }
+        outputList->unstack(input, INT_ARG(0));

         //OVERWRITE_RESULT(list);
-        setupResultList(list, block);
+        //
         return Status::OK();
     }
 }
@@ -26,11 +26,11 @@
 namespace nd4j {
 namespace ops {
     CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) {
-        REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
-        auto rng = block.getRNG();
+        // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
+        auto rng = block.getRng();
         auto z = OUTPUT_VARIABLE(0);

-        z->p(Nd4jLong(0), rng->getSeed());
+        z->p(Nd4jLong(0), rng.rootState());

         return Status::OK();
     }
|
@ -27,8 +27,9 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) {
|
CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) {
|
||||||
REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
|
// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
|
||||||
auto rng = block.getRNG();
|
auto rng = block.getRng(); //.getRNG();
|
||||||
|
|
||||||
Nd4jLong seed = 0;
|
Nd4jLong seed = 0;
|
||||||
if (block.getIArguments()->size() > 0) {
|
if (block.getIArguments()->size() > 0) {
|
||||||
seed = INT_ARG(0);
|
seed = INT_ARG(0);
|
||||||
|
@ -41,8 +42,8 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
|
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
|
||||||
refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
//refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
||||||
|
rng.setSeed((int)seed);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
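With the RandomGenerator now held by value, set_seed stores the seed through rng.setSeed() and get_seed reports rng.rootState(). A minimal usage sketch of the pair, mirroring the Test_SetSeed_1 test added further down in this commit (the literal values are only illustrative):

    auto x = NDArrayFactory::create<int>('c', {1, 1}, {120});
    auto y = NDArrayFactory::create<int>(5);

    nd4j::ops::set_seed setOp;
    auto setRes = setOp.execute({&x, &y}, {}, {120, 5}, {}, false, nd4j::DataType::INT32);

    nd4j::ops::get_seed getOp;
    auto getRes = getOp.execute({}, {}, {});   // single output holds the generator's root state

    delete setRes;
    delete getRes;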
@@ -25,7 +25,7 @@
 namespace nd4j {
 namespace ops {
-    OP_IMPL(Log1p, 2, 1, true) {
+    OP_IMPL(Log1p, 1, 1, true) {
         auto x = INPUT_VARIABLE(0);
         auto z = OUTPUT_VARIABLE(0);
@@ -27,7 +27,7 @@
 namespace nd4j {
 namespace ops {

-    OP_IMPL(mergemaxindex, -1, 1, false) {
+    CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {

         REQUIRE_OK(this->validateInputDimensionsMatch(block));
         auto output = OUTPUT_VARIABLE(0);
@@ -49,6 +49,15 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex);
             ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
     }
 }
+    DECLARE_SHAPE_FN(mergemaxindex) {
+        auto in = inputShape->at(0);
+        auto dtype = DataType::INT32;
+        if (block.getIArguments()->size()> 0)
+            dtype = (DataType)INT_ARG(0);
+
+        auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace());
+        return SHAPELIST(resShape);
+    }
 }

 #endif
@@ -30,7 +30,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_double)
-DECLARE_OP(to_double, 1, 1, true);
+DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0);
 #endif

 /**
@@ -39,7 +39,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_float16)
-DECLARE_OP(to_float16, 1, 1, true);
+DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0);
 #endif

 /**
@@ -48,7 +48,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_float32)
-DECLARE_OP(to_float32, 1, 1, true);
+DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0);
 #endif

 /**
@@ -57,7 +57,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_int32)
-DECLARE_OP(to_int32, 1, 1, true);
+DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0);
 #endif

 /**
@@ -66,7 +66,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_int64)
-DECLARE_OP(to_int64, 1, 1, true);
+DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0);
 #endif

 /**
@@ -75,7 +75,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_uint32)
-DECLARE_OP(to_uint32, 1, 1, true);
+DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0);
 #endif

 /**
@@ -84,7 +84,7 @@ namespace nd4j {
 * PLEASE NOTE: This op is disabled atm, and reserved for future releases.
 */
 #if NOT_EXCLUDED(OP_to_uint64)
-DECLARE_OP(to_uint64, 1, 1, true);
+DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0);
 #endif

 /**
@@ -65,9 +65,15 @@ namespace nd4j {
 #if NOT_EXCLUDED(OP_mergemax)
 DECLARE_OP(mergemax, -1, 1, false);
 #endif
+/*
+ * Complete tensor with max indices merged from all input tensors list
+ *
+ * INPUT: tensors with the same shape
+ * OUTPUT: integer tensor with the same shape
+ * INT_ARG: result type (one of int), INT32 by default
+ */
 #if NOT_EXCLUDED(OP_mergemaxindex)
-DECLARE_OP(mergemaxindex, -1, 1, false);
+DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0);
 #endif

 #if NOT_EXCLUDED(OP_mergeadd)
@@ -25,7 +25,7 @@ namespace nd4j {
 namespace ops {
 namespace helpers {
     template <typename X, typename Z>
-    void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) {
+    void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, X* min_val, X* max_val) {
         int tid = blockIdx.x * blockDim.x + threadIdx.x;
         auto dx = reinterpret_cast<X*>(xBuffer);
         auto result = reinterpret_cast<Z*>(zBuffer);
@@ -42,19 +42,19 @@ namespace nd4j {
         }
         __syncthreads();

-        Z binSize = (max_val - min_val) / (numBins);
+        X binSize = X((*max_val - *min_val) / numBins);

         for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
-            bins[e] = (Z) 0.0f;
+            bins[e] = (Z) 0;
         }
         __syncthreads();

-        for (int e = tid; e < length; e+= blockDim.x * gridDim.x) {
-            int idx = (int) ((dx[e] - min_val) / binSize);
-            if (idx < 0) idx = 0;
-            else if (idx >= numBins) idx = numBins - 1;
-            nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
+        for (int e = tid; e < length; e += blockDim.x * gridDim.x) {
+            int idx = int((dx[e] - *min_val) / binSize);
+            idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0);
+            idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1));
+            nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1);
+            // bins[idx]++;
         }
         __syncthreads();
@@ -82,7 +82,7 @@ namespace nd4j {

         // nullify shared memory for future accumulation
         for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
-            bins[e] = (Z) 0.0f;
+            bins[e] = (Z) 0;
         }

         // accumulate reduced bins
@@ -90,7 +90,7 @@ namespace nd4j {
         Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins);

         for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
-            bins[e] += ptrBuf[e];
+            math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]);
         }
     }
     __syncthreads();
@@ -109,24 +109,26 @@ namespace nd4j {
     }

     template <typename X, typename Z>
-    static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
+    static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, void* min_val, void* max_val) {
         int numThreads = 256;
         int numBlocks = nd4j::math::nd4j_max<int>(256, nd4j::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
         int workspaceSize = numBlocks * numBins;
-        auto tmp = NDArrayFactory::create<Z>('c',{workspaceSize});
+        auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize});

-        histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val);
+        histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));

         cudaStreamSynchronize(*context->getCudaStream());
     }

     void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
         Nd4jLong numBins = output.lengthOf();
-        double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
-        double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);
-
-        BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES);
+        NDArray::registerSpecialUse({&output}, {&input});
+
+        auto min_val = input.reduceNumber(reduce::SameOps::Min);
+        auto max_val = input.reduceNumber(reduce::SameOps::Max);
+        // min_val.printIndexedBuffer("MIN");
+        // max_val.printIndexedBuffer("MAX");
+        BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.shapeInfo(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val.specialBuffer(), max_val.specialBuffer()), LIBND4J_TYPES, INTEGER_TYPES);
         NDArray::registerSpecialUse({&output}, {&input});
     }
 }
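The helper now computes min and max on the device (as NDArrays) and hands their special buffers to the kernel, which is why histogramKernel dereferences *min_val and *max_val instead of taking host doubles. A minimal usage sketch of the op, mirroring the histogram_test1 test added later in this commit:

    auto in  = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {3}, {3, 3, 3});

    nd4j::ops::histogram op;
    auto res = op.execute({&in}, {}, {3}, {});   // INT_ARG(0) = number of bins
    // res->at(0) is expected to match exp: each of the 3 bins receives 3 of the 9 values
    delete res;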
@@ -68,21 +68,21 @@ namespace helpers {
     static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         auto step = gridDim.x * blockDim.x;
-        __shared__ bool shouldSelectShared;
+        __shared__ unsigned int shouldSelectShared;
         if (threadIdx.x == 0) {
-            shouldSelectShared = shouldSelect[0];
+            shouldSelectShared = (unsigned int)shouldSelect[0];
         }
         __syncthreads();
         for (int j = numSelected - 1 - tid; j >= 0; j -= step) {
             if (shouldSelectShared) {
                 if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i],
                         indexBuf[selectedIndicesData[j]], T(threshold)))
-                    shouldSelectShared = false;
+                    atomicCAS(&shouldSelectShared, 1, 0);
             }
         }
         __syncthreads();
         if (threadIdx.x == 0) {
-            *shouldSelect = shouldSelectShared;
+            *shouldSelect = shouldSelectShared > 0;
         }
     }
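The shared flag switches from bool to unsigned int because CUDA's atomicCAS only operates on word-sized integers, so concurrent threads can now clear it without a data race. A minimal standalone sketch of the same pattern (the kernel and variable names here are ours, not library code):

    // Many threads may veto a candidate; the flag is cleared atomically at most once.
    __global__ void anyVetoKernel(const bool* veto, int n, bool* keep) {
        __shared__ unsigned int flag;               // bool is not a legal atomicCAS operand
        if (threadIdx.x == 0) flag = 1;             // start from "keep"
        __syncthreads();
        for (int j = threadIdx.x; j < n; j += blockDim.x)
            if (veto[j]) atomicCAS(&flag, 1, 0);    // only the 1 -> 0 transition ever happens
        __syncthreads();
        if (threadIdx.x == 0) *keep = flag > 0;
    }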
@@ -48,8 +48,10 @@ namespace nd4j {
             auto x = reinterpret_cast<T*>(inArrs[i]);
             auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
             auto val = x[shape::getIndexOffset(e, xShape, length)];;
-            if (mVal < val)
-                mIdx = static_cast<Z>(e);
+            if (mVal < val) {
+                mIdx = static_cast<Z>(i);
+                mVal = val;
+            }
         }
         __syncthreads();
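The fix makes the kernel record which input array (i) currently holds the maximum for element e, and it now updates the running maximum as well; previously mVal never changed and the element index e was stored instead of the array index. The intended per-element reduction, written as a plain sketch with purely illustrative names:

    // Illustrative only: argmax across input arrays for one element position e.
    int bestArray = 0;
    float bestVal = inputs[0][e];          // "inputs" stands in for the list of device buffers
    for (int i = 1; i < numArrays; i++) {
        float val = inputs[i][e];
        if (bestVal < val) {
            bestArray = i;                 // remember which array won ...
            bestVal = val;                 // ... and the value it won with, so later arrays compare correctly
        }
    }
    output[e] = bestArray;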
@@ -123,14 +123,236 @@ namespace nd4j {
         nSamplingKernel<T><<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference);
     }

+    /*
+     * binarySearch - find element in haystack buffer (haystack - sorted device memory)
+     * */
     int binarySearch(const int *haystack, const int needle, const int totalElements) {
-        return 0;
+        int firstIndex = 0;
+        int lastIndex = totalElements - 1;
+        int halfIndex = nd4j::math::nd4j_floor<float, int>((lastIndex + firstIndex) / (float) 2);
+
+        while(haystack[halfIndex] != needle && firstIndex < lastIndex) {
+            if (needle < haystack[halfIndex]) {
+                lastIndex = halfIndex - 1;
+            } else if (needle > haystack[halfIndex]) {
+                firstIndex = halfIndex + 1;
+            }
+            halfIndex = nd4j::math::nd4j_floor<float, int>((lastIndex + firstIndex) / (float) 2);
+        }
+
+        return (haystack[halfIndex] == needle) ? halfIndex : -1;
+    }
+
+    template <typename T>
+    __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) {
+        auto start = blockIdx.x * blockDim.x + threadIdx.x;
+        auto step = blockDim.x * gridDim.x;
+
+        for (auto i = start; i < vectorLength; i += step) {
+            neu1[i] += infVector[i];
+        }
     }

-    void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) {
+    template <typename T>
+    void skipgram_(NDArray& s0, NDArray& s1, NDArray& s1n, NDArray& expTableV, NDArray& negTableV, NDArray& infV, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds) {
+        // void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) {
+        auto syn0 = reinterpret_cast<T*>(s0.specialBuffer());
+        auto syn1 = reinterpret_cast<T*>(s1.specialBuffer());
+        auto syn1Neg = reinterpret_cast<T*>(s1n.specialBuffer());
+        auto expTable = reinterpret_cast<T*>(expTableV.specialBuffer());
+        auto negTable = reinterpret_cast<T*>(negTableV.specialBuffer());
+        auto infVector = reinterpret_cast<T*>(infV.specialBuffer());
+        const int vocabSize = s0.sizeAt(0);
+        const int vectorLength = s0.sizeAt(1);
+        const int expLength = expTableV.lengthOf();
+        const int negLength = negTableV.lengthOf();
+        indices.tickReadDevice();
+        indices.syncToHost();
+        codes.tickReadDevice();
+        codes.syncToHost();
+        auto stream = s0.getContext()->getCudaStream();
+
+        T* neu1e; // = new T[vectorLength];
+        //memset(neu1e, 0, vectorLength * sizeof(T));
+        auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength);
+        err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength);
+        // hierarchic softmax goes first (if enabled)
+
+        auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength);
+        auto irow = 0;
+        if (hsRounds > 0) {
+            for (int r = 0; r < hsRounds; r++) {
+                irow = indices.t<int>(r);
+                if (irow < 0 || irow >= vocabSize)
+                    break;
+
+                hSoftmax_<T>(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes.t<int8_t>(r), expLength, infVector != nullptr, stream);
+            }
+        }
+
+        // negative sampling goes second (if enabled)
+        auto nsStarter = ngStarter;
+        irow = nsStarter;
+        if (nsRounds > 0) {
+            for (int r = 0; r < nsRounds + 1; r++) {
+                if (r == 0) {
+                    // target is known in advance
+                } else {
+                    randomValue = randomValue * (unsigned long long) 25214903917 + 11;
+                    auto idx = nd4j::math::nd4j_abs<Nd4jLong >((randomValue >> 16) % negLength);
+                    irow = idx >= negLength ? -1 : negTableV.e<int>(idx);
+
+                    if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1;
+                    if (irow == nsStarter)
+                        continue;
+                }
+
+                nSampling_<T>(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream);
+            }
+        }
+
+        if (infVector == nullptr) {
+            addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
+        } else {
+            addInfVectorKernel<T><<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength);
+        }
+
+        err = cudaFree(neu1e);
+        if (0 != err) {
+            throw cuda_exception::build("helpers::skipgram_: Cannot deallocate temp memory for lingual net", err);
+        }
+    }
+    BUILD_SINGLE_TEMPLATE(template void skipgram_, (NDArray& syn0, NDArray& syn1, NDArray& syn1Neg, NDArray& expTable, NDArray& negTable, NDArray& infVector, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds), FLOAT_TYPES);
+
+    /*
+     * batched version of skipgram routine
+     * */
+    template <typename T>
+    void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTableV, NDArray& negTableV, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) {
+        // (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray& infVector, NDArray& targets, NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& lr, NDArray& nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) {
+        //auto syn0 = reinterpret_cast<T*>(vsyn0);
+        //auto syn1 = reinterpret_cast<T*>(vsyn1);
+        //auto syn1Neg = reinterpret_cast<T*>(vsyn1Neg);
+        auto stream = s0.getContext()->getCudaStream();
+        negTableV.tickReadDevice();
+        negTableV.syncToHost();
+        const auto expTable = reinterpret_cast<T*>(expTableV.specialBuffer());
+        const auto negTable = reinterpret_cast<T*>(negTableV.buffer());
+        const auto infVector = (T*)nullptr; //reinterpret_cast<T*>(infVector.specialBuffer());
+
+        const int vocabSize = s0.sizeAt(0);
+        const int vectorLength = s0.sizeAt(1);
+        const int expLength = expTableV.lengthOf();
+        const int negLength = negTableV.lengthOf();
+
+        //T sneu1e[600];
+
+        //const auto numThreads = omp_get_max_threads();
+        const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1);
+        const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1);
+
+        // regular mode provides 0 guarantees for reproducibility
+        auto numTargets = targets.lengthOf();
+        targets.syncToHost();
+        indices.syncToHost();
+        codes.syncToHost();
+        lr.syncToHost();
+        nextRandom.syncToHost();
+        negStarters.tickReadDevice();
+        negStarters.syncToHost();
+        auto bTarget = reinterpret_cast<int*>(targets.buffer()); //targets.bufferAsT<int>();
+        auto bIndices = reinterpret_cast<int*>(indices.buffer()); //indices.bufferAsT<int>();
+        auto bCodes = reinterpret_cast<int8_t*>(codes.buffer()); //codes.bufferAsT<int8_t>();
+
+        // PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads))
+        for (int t = 0; t < numTargets; t++) {
+            T* neu1e;//lvectorLength <= 600 ? sneu1e : new T[vectorLength];
+            auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T));
+            err = cudaMemset(neu1e, 0, vectorLength * sizeof(T));
+            //memset(neu1e, 0, vectorLength * sizeof(T));
+
+            auto target = bTarget[t];
+            auto alpha = lr.e<double>(t);
+            unsigned long long randomValue = nextRandom.e<Nd4jLong>(t);
+
+            auto syn0row = reinterpret_cast<T*>(s0.specialBuffer()) + (target * vectorLength);
+
+            if (hsRounds > 0) {
+                int irow = 0;
+                auto cShift = t * idxShift;
+
+                for (int e = 0; e < hsRounds; e++) {
+                    irow = bIndices[e + cShift];
+                    if (irow < 0 || irow >= vocabSize)
+                        continue;
+
+                    auto syn1row = reinterpret_cast<T*>(s1.getSpecialBuffer()) + (irow * vectorLength);
+                    auto code = bCodes[e + cShift];
+
+                    //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code);
+                    hSoftmax_<T>(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, expLength, false, stream);
+                }
+            }
+
+
+            if (nsRounds > 0) {
+                int irow = negStarters.e<int>(t);
+                int nsStarter = irow;
+                for (int r = 0; r < nsRounds + 1; r++) {
+                    if (r == 0) {
+                        // target is known in advance
+                    } else {
+                        randomValue = randomValue * (unsigned long long) 25214903917 + 11;
+                        auto idx = nd4j::math::nd4j_abs<Nd4jLong >((randomValue >> 16) % negLength);
+                        irow = idx >= negLength ? -1 : static_cast<int>(negTable[idx]);
+
+                        if (irow < 0 || irow >= vocabSize)
+                            irow = randomValue % (vocabSize - 1) + 1;
+
+                        if (irow == nsStarter)
+                            continue;
+                    }
+                    auto syn1row = reinterpret_cast<T*>(s1n.getSpecialBuffer()) + (irow * vectorLength);
+
+                    nSampling_<T>(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, false, stream);
+                }
+            }
+            addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
+
+            // optionally release temp arrays
+            err = cudaFree(neu1e);
+            if (err != 0) {
+                break;
+            }
+            // if (vectorLength > 600)
+            //     delete[] neu1e;
+        }
+    }
+    BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads), FLOAT_TYPES);
+
+    void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable,
+                  NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) {
         auto xType = syn0.dataType();
+        // single round case
+        if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) {
+            auto hsRounds = codes.lengthOf();
+            target.syncToHost();
+            ngStarter.syncToHost();
+            alpha.syncToHost();
+            randomValue.syncToHost();
+
+            auto targetV = target.isEmpty() ? -1 : target.e<int>(0);
+            auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e<int>(0);
+            auto alphaV = alpha.e<double>(0);
+            auto randomV = randomValue.e<Nd4jLong>(0);
+            BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), FLOAT_TYPES);
+        } else if (ngStarter.isVector() || target.isVector()){
+            // batch mode
+            // NDArray& infVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads)
+            BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), FLOAT_TYPES);
+        } else
+            throw std::runtime_error("SkipGram: target must have rank 0 or 1");
     }

     template <typename T>
     static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) {
         __shared__ bool hasError;
@@ -157,16 +379,6 @@ namespace nd4j {
         }
     }

-    template <typename T>
-    __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) {
-        auto start = blockIdx.x * blockDim.x + threadIdx.x;
-        auto step = blockDim.x * gridDim.x;
-
-        for (auto i = start; i < vectorLength; i += step) {
-            neu1[i] += infVector[i];
-        }
-    }
-
     template <typename T>
     __global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) {
         auto start = blockIdx.x * blockDim.x + threadIdx.x;
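Both the single-round and the batched paths above draw negative samples with the same linear congruential step (the multiplier is the one used by java.util.Random) and map the upper bits of the state into the negative table. A self-contained sketch of that draw; the helper name is ours, not library code:

    // Sketch of the negative-sample index draw used in both code paths above.
    static int nextNegativeIndex(unsigned long long &randomValue, int negLength) {
        randomValue = randomValue * 25214903917ULL + 11;                        // advance the LCG state
        auto idx = (long long)((randomValue >> 16) % (unsigned long long)negLength);  // upper bits, wrapped into the table
        return (int) idx;                                                       // caller maps idx through negTable and re-draws on collisions
    }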
@@ -26,7 +26,13 @@ namespace nd4j {
 namespace helpers {
     template<typename T>
     void toggle_bits__(NDArray &in, NDArray &out) {
+        NDArray::prepareSpecialUse({&out}, {&in});
+        auto lambda = LAMBDA_T(_x) {
+            return ~_x;//eUtils::flip_bits(_x);
+        };
+
+        in.applyLambda(lambda, &out);
+        NDArray::registerSpecialUse({&out}, {&in});
     }
     BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES);

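The lambda flips every bit of each element with the ~ operator. For a single int8_t value that works out to:

    // Illustration of the lambda's effect on one int8_t element.
    int8_t x = 5;        // binary 00000101
    int8_t z = ~x;       // binary 11111010, i.e. -6 in two's complement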
@@ -685,13 +685,12 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
 BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);

-
 //////////////////////////////////////////////////////////////////////////
 void eye(nd4j::LaunchContext * context, NDArray& output) {

     output.setIdentity();
 }


 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 template <typename T>
 static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong* shape, Nd4jLong* inputOffsets, T* norm2Buf, Nd4jLong* norm2shape, T clipNorm) {
@@ -502,6 +502,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
     auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
     auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
+    auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});

     auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
     auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});

@ -223,21 +223,221 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {1, 1}, {120});
|
||||||
|
auto y = NDArrayFactory::create<int>(5);
|
||||||
|
|
||||||
|
nd4j::ops::set_seed op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {120, 5}, {}, false, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
// result->at(0)->printIndexedBuffer("RES SEED");
|
||||||
|
nd4j::ops::get_seed getOp;
|
||||||
|
auto getRes = getOp.execute({}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), getRes->status());
|
||||||
|
// getRes->at(0)->printIndexedBuffer("Output RES GET SEED");
|
||||||
|
// ASSERT_EQ(result->at(0)->t<bool>(0), true);
|
||||||
|
delete result;
|
||||||
|
delete getRes;
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, scatterMul_test1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_mul op;
|
||||||
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10, 2, 3, 4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_div op;
|
||||||
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Scatter Div");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, scatterSub_test1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9, 1, 3, 4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_sub op;
|
||||||
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Scatter Sub");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.7, 0.9, 1, 1});
|
||||||
|
|
||||||
|
nd4j::ops::hardsigmoid op;
|
||||||
|
auto result = op.execute({&matrix}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Hadrdsigmoid 2x2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.2, 0.4, 0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::hardsigmoid_bp op;
|
||||||
|
auto result = op.execute({&matrix, &eps}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Hadrdsigmoid 2x2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, hardtanh_test1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
|
||||||
|
|
||||||
|
nd4j::ops::hardtanh op;
|
||||||
|
auto result = op.execute({&matrix}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Hardtanh 2x2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, hardtanh_test2) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::hardtanh_bp op;
|
||||||
|
auto result = op.execute({&matrix, &eps}, {}, {}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Hardtanh_bp 2x2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, histogram_test1) {
    auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {3}, {3, 3, 3});

    nd4j::ops::histogram op;
    auto result = op.execute({&matrix}, {}, {3}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Histogram3");
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, histogram_test2) {
    auto matrix = NDArrayFactory::create<double>('c', {3}, {1, 2, 1});
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {4}, {2, 0, 0, 1});

    nd4j::ops::histogram op;
    auto result = op.execute({&matrix}, {}, {4}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    z->printIndexedBuffer("Histogram4");
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
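
// The integer argument appears to be the number of evenly spaced bins over the input's
// [min, max] range. In the first test the range [-4, 4] is split into 3 bins, each catching
// 3 of the 9 values; in the second the range [1, 2] is split into 4 bins, so the two 1s land
// in the first bin and the single 2 in the last: {2, 0, 0, 1}.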

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test1) {
    auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
    // auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 3}, {3, 3, 3});

    nd4j::ops::identity op;
    auto result = op.execute({&matrix}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    ASSERT_TRUE(matrix.equalsTo(z));

    delete result;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test2) {
    auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
    auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    // auto exp = NDArrayFactory::create<float>('c', {3, 3});
    nd4j::ops::identity_bp op;
    auto result = op.execute({&matrix, &eps}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    z->printIndexedBuffer("Identity_BP");
    ASSERT_TRUE(z->equalsTo(eps));

    delete result;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Log1p_test1) {
    auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4});
    auto y = NDArrayFactory::create<float>('c', {3, 3}, {5, 4, 3, 2, 1, 2, 3, 4, 5});
    // auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    // auto exp = NDArrayFactory::create<float>('c', {3, 3});
    nd4j::ops::Log1p op;
    y.applyTransform(nd4j::transform::Log, nullptr, nullptr);
    auto result = op.execute({&matrix}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    z->printIndexedBuffer("Log1p");
    ASSERT_TRUE(z->equalsTo(y));

    delete result;
}
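
// Log1p_test1 builds y as matrix + 1 ({5, 4, 3, ...}) and then applies the plain Log
// transform to it, so comparing the op output against y effectively checks
// log1p(x) = ln(1 + x) element-wise.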

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) {

@ -737,6 +737,44 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
    auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
    nd4j::ops::mergemaxindex op;

    auto ress = op.execute({&x, &y, &z}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    // ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
    // ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress->at(0)->equalsTo(exp));
    delete ress;
}
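
// mergemaxindex takes several arrays of the same shape and, for every position, returns the
// index of the input that holds the maximum there. With the data above, y (input 1) wins at
// even positions and z (input 2) at odd ones, hence {1, 2, 1, 2, ...}; the second test below
// only changes the requested output type to INT64 via the integer argument.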

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
    nd4j::ops::mergemaxindex op;

    auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {});

    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    // ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
    // ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex2");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress->at(0)->equalsTo(exp));
    delete ress;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestDropout_1) {

@ -752,8 +790,60 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
    delete ress;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
    nd4j::ops::mod op;

    auto ress = op.execute({&x, &y}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    // ress->at(0)->printIndexedBuffer("MOD Result is ");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress->at(0)->equalsTo(exp));
    delete ress;
}
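
// mod is an element-wise truncated remainder, x % y: 1 % 10 = 1, 2 % 2 = 0, 3 % 30 = 3,
// and so on, which matches the alternating {1, 0, 3, 0, 5, 0, 7, 0} expected above.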

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_BP_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
    nd4j::ops::mod_bp op;

    auto ress = op.execute({&x, &y, &eps}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    // ress->at(0)->printIndexedBuffer("MOD_BP Result is ");

    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress->at(0)->equalsTo(exp));
    delete ress;
}

///////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestRank_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<int>(3);
    nd4j::ops::rank op;

    auto ress = op.execute({&x}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    ress->at(0)->printIndexedBuffer("RANK Result is ");

    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress->at(0)->equalsTo(exp));
    delete ress;
}
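
// rank returns the number of dimensions of its input as a scalar, so for the {2, 2, 2}
// array above the expected result is simply 3.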

TEST_F(DeclarableOpsTests6, TestDropout_2) {
    // auto x0 = NDArrayFactory::create<double>('c', {10, 10});
    // auto x1 = NDArrayFactory::create<double>('c', {10, 10});

@ -24,6 +24,7 @@
#include <helpers/helper_hash.h>
#include <NDArray.h>
#include <array/NDArrayList.h>
#include <GradCheck.h>

using namespace nd4j;

@ -3605,6 +3606,289 @@ TEST_F(DeclarableOpsTests7, transpose_test3) {
    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test1) {

    auto input = NDArrayFactory::create<double>('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7});
    NDArray exp = NDArrayFactory::create<double>({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446});

    nd4j::ops::rationaltanh op;
    auto result = op.execute({&input}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Output rationaltanh");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test2) {

    auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
    NDArray exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446});

    nd4j::ops::rationaltanh op;
    auto result = op.execute({&input}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Output rationaltanh");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test3) {

    auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
    NDArray exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971});

    nd4j::ops::rationaltanh_bp op;
    auto result = op.execute({&input, &eps}, {}, {});
    auto output = result->at(0);
    // output->printBuffer("Output rationaltanh BP");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) {

    auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
    NDArray exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998});

    nd4j::ops::rectifiedtanh op;
    auto result = op.execute({&input}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Output rectifiedtanh");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}
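
// rectifiedtanh appears to compute max(0, tanh(x)) (an assumption based on the expected
// values above): tanh(1) = 0.761594, tanh(2) = 0.964028, and so on, with negative inputs
// clipped to 0.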

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) {

    auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
    NDArray exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027});

    nd4j::ops::rectifiedtanh_bp op;
    auto result = op.execute({&input, &eps}, {}, {});
    auto output = result->at(0);
    // output->printBuffer("Output rectifiedtanh BP");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}

TEST_F(DeclarableOpsTests7, RealDiv_1) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1, 2});
    NDArray e = NDArrayFactory::create<float>('c', {1, 2, 2}, {2, 1, 4, 2});

    nd4j::ops::realdiv op;
    auto result = op.execute({&x, &y}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output RealDiv");
    ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}
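
// realdiv divides with broadcasting: x has shape {1, 2, 1} and y {1, 2}, so the result has
// shape {1, 2, 2} with z[0, i, j] = x[0, i, 0] / y[0, j], i.e. {2/1, 2/2, 4/1, 4/2} = {2, 1, 4, 2}.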

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1, 2});
    NDArray e0 = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 5});
    NDArray e1 = NDArrayFactory::create<float>('c', {1, 2}, {-14, -5});
    NDArray eps = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 2, 3, 4});

    nd4j::ops::realdiv_bp op;
    auto result = op.execute({&x, &y, &eps}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z0 = result->at(0);
    auto z1 = result->at(1);
    // z0->printShapeInfo("Output RealDiv BP0 shape");
    // z1->printShapeInfo("Output RealDiv BP1 shape");
    // z0->printIndexedBuffer("Output RealDiv BP0");
    // z1->printIndexedBuffer("Output RealDiv BP1");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e0.equalsTo(z0));
    ASSERT_TRUE(e1.equalsTo(z1));

    delete result;
}
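
// The expected backprop values can be checked by hand: with z[0, i, j] = x[i] / y[j], the
// gradient wrt x sums eps over the broadcast axis, dL/dx[i] = sum_j eps[i, j] / y[j]
// = {1/1 + 2/2, 3/1 + 4/2} = {2, 5}, and the gradient wrt y is
// dL/dy[j] = -sum_i eps[i, j] * x[i] / y[j]^2 = {-(2 + 12), -(4 + 16)/4} = {-14, -5}.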

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_1) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    // NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1, 2});
    NDArray e = NDArrayFactory::create<Nd4jLong>({1, 2, 1});

    nd4j::ops::shapes_of op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output ShapesOf");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_2) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1, 2});
    NDArray e0 = NDArrayFactory::create<Nd4jLong>({1, 2, 1});
    NDArray e1 = NDArrayFactory::create<Nd4jLong>({1, 2});

    nd4j::ops::shapes_of op;
    auto result = op.execute({&x, &y}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z0 = result->at(0);
    auto z1 = result->at(1);
    // z0->printIndexedBuffer("Output shapes2");
    // z1->printIndexedBuffer("Output shapes2");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e0.equalsTo(z0));
    ASSERT_TRUE(e1.equalsTo(z1));

    delete result;
}

TEST_F(DeclarableOpsTests7, Size_1) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray e = NDArrayFactory::create<Nd4jLong>(2);

    nd4j::ops::size op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output SIZE");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}

TEST_F(DeclarableOpsTests7, Size_2) {

    NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
    NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray e = NDArrayFactory::create<Nd4jLong>(10);

    nd4j::ops::size op;
    auto result = op.execute({&y}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output SIZE");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}
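
// size returns the total number of elements of its input as a scalar: 1*2*1 = 2 for x and
// 5*2 = 10 for y.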

TEST_F(DeclarableOpsTests7, Softplus_1) {

    NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});

    nd4j::ops::softplus op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output Softplus");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}
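
// softplus(x) = ln(1 + exp(x)): softplus(1) = ln(1 + e) ≈ 1.3132616 and softplus(2) ≈ 2.126928,
// matching the expected array, and it approaches x itself for large inputs
// (softplus(10) ≈ 10.000046).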

TEST_F(DeclarableOpsTests7, Softplus_BP_1) {

    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    // NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
    NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    nd4j::ops::softplus ffOP;
    nd4j::ops::softplus_bp bpOp;
    const OpArgsHolder argsHolderFF({&x}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &eps}, {}, {});

    bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(gradOK);
    //
    // auto z = result->at(0);
    // z->printIndexedBuffer("Output Softplus");
    // ASSERT_TRUE(e.isSameShape(z));
    // ASSERT_TRUE(e.equalsTo(*z));
    //
    // delete result;
}

TEST_F(DeclarableOpsTests7, Softsign_1) {

    NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667});

    nd4j::ops::softsign op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Output Softsign");
    // ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    delete result;
}
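
// softsign(x) = x / (1 + |x|): 1 -> 0.5, 2 -> 0.6666667, 3 -> 0.75, ..., which is the expected
// array above. The BP test below relies on GradCheck rather than a hand-written expectation.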

TEST_F(DeclarableOpsTests7, Softsign_BP_1) {

    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    // NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
    NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    nd4j::ops::softsign ffOP;
    nd4j::ops::softsign_bp bpOp;
    const OpArgsHolder argsHolderFF({&x}, {}, {});
    const OpArgsHolder argsHolderBP({&x, &eps}, {}, {});

    bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(gradOK);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, fill_test2) {

@ -3644,6 +3928,185 @@ TEST_F(DeclarableOpsTests7, fill_test3) {
    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ToggleBits_test1) {

    auto x = NDArrayFactory::create<int>('c', {2}, {2, 2});
    auto exp = NDArrayFactory::create<int>('c', {2}, {-3, -3});

    nd4j::ops::toggle_bits op;
    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT32);
    auto output = result->at(0);

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    // output->printIndexedBuffer("Toggled");
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}
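
// toggle_bits flips every bit of an integer input; in two's complement ~2 == -3, which is
// why {2, 2} becomes {-3, -3}. The second test below toggles a second input as well
// ({1, 1} -> {-2, -2}).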

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ToggleBits_test2) {

    auto x = NDArrayFactory::create<int>('c', {2}, {2, 2});
    auto y = NDArrayFactory::create<int>('c', {2}, {1, 1});
    auto exp0 = NDArrayFactory::create<int>('c', {2}, {-3, -3});
    auto exp1 = NDArrayFactory::create<int>('c', {2}, {-2, -2});

    nd4j::ops::toggle_bits op;
    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
    auto output = result->at(0);
    auto z = result->at(1);

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    // output->printIndexedBuffer("Toggled");
    ASSERT_TRUE(exp0.isSameShape(output));
    ASSERT_TRUE(exp0.equalsTo(output));
    ASSERT_TRUE(exp1.isSameShape(z));
    ASSERT_TRUE(exp1.equalsTo(z));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Truncatediv_test1) {
    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray y = NDArrayFactory::create<double>('c', {5, 2}, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2});
    NDArray exp = NDArrayFactory::create<double>('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5});

    nd4j::ops::truncatediv op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
    // output->printIndexedBuffer("Truncatediv");
    ASSERT_TRUE(exp.isSameShape(output));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Truncatediv_test2) {
    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray y = NDArrayFactory::create<double>('c', {1, 2}, {2, 2});
    NDArray exp = NDArrayFactory::create<double>('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5});

    nd4j::ops::truncatediv op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
    // output->printIndexedBuffer("Truncatediv broadcast");
    ASSERT_TRUE(exp.isSameShape(output));

    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expI = NDArrayFactory::create<int>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expL = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expF16 = NDArrayFactory::create<float16>('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f});

    nd4j::ops::to_int32 op32;
    nd4j::ops::to_int64 op64;
    auto result32 = op32.execute({&x}, {}, {});
    auto result64 = op64.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result32->status());
    ASSERT_EQ(ND4J_STATUS_OK, result64->status());
    auto out1 = result32->at(0);
    // out1->printIndexedBuffer("OUT_I");
    auto out2 = result64->at(0);
    // out2->printIndexedBuffer("OUT_L");

    ASSERT_TRUE(expI.equalsTo(out1));
    ASSERT_TRUE(expL.equalsTo(out2));

    delete result32;
    delete result64;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test2) {
    NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray expH = NDArrayFactory::create<float16>('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f});

    nd4j::ops::to_float32 op32;
    nd4j::ops::to_float16 op16;
    auto result32 = op32.execute({&x}, {}, {});
    auto result16 = op16.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result32->status());
    ASSERT_EQ(ND4J_STATUS_OK, result16->status());
    auto out1 = result32->at(0);
    // out1->printIndexedBuffer("OUT_F");
    auto out2 = result16->at(0);
    // out2->printIndexedBuffer("OUT_H");

    ASSERT_TRUE(expF.equalsTo(out1));
    ASSERT_TRUE(expH.equalsTo(out2));

    delete result32;
    delete result16;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test3) {
    NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray exp32 = NDArrayFactory::create<unsigned int>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray exp64 = NDArrayFactory::create<uint64_t>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});

    nd4j::ops::to_uint32 op32;
    nd4j::ops::to_uint64 op64;
    auto result32 = op32.execute({&x}, {}, {});
    auto result64 = op64.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result32->status());
    ASSERT_EQ(ND4J_STATUS_OK, result64->status());
    auto out1 = result32->at(0);
    // out1->printIndexedBuffer("OUT_U32");
    auto out2 = result64->at(0);
    // out2->printIndexedBuffer("OUT_U64");

    ASSERT_TRUE(exp32.equalsTo(out1));
    ASSERT_TRUE(exp64.equalsTo(out2));

    delete result32;
    delete result64;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test4) {
    NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray exp32 = NDArrayFactory::create<float>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});
    NDArray exp64 = NDArrayFactory::create<double>('c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11});

    nd4j::ops::to_float32 op32;
    nd4j::ops::to_double op64;
    auto result32 = op32.execute({&x}, {}, {});
    auto result64 = op64.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result32->status());
    ASSERT_EQ(ND4J_STATUS_OK, result64->status());
    auto out1 = result32->at(0);
    out1->printIndexedBuffer("OUT_F");
    auto out2 = result64->at(0);
    out2->printIndexedBuffer("OUT_D");

    ASSERT_TRUE(exp32.equalsTo(out1));
    ASSERT_TRUE(exp64.equalsTo(out2));

    delete result32;
    delete result64;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, mirrorPad_test1) {

@ -78,6 +78,71 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
    delete tads;
}

TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
    NDArrayList list(0, true);
    auto x = NDArrayFactory::create<double>('c', {10, 100});
    auto tads = x.allTensorsAlongDimension({1});
    for (int e = 0; e < 10; e++) {
        auto row = NDArrayFactory::create_<double>('c', {100});
        row->assign((double) e);
        //list.write(e, row);
        tads->at(e)->assign(row);
        delete row;
    }

    nd4j::ops::unstack_list op;

    auto result = op.execute(&list, {&x}, {}, {0});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_EQ(list.elements(), 10);

    // auto z = result->at(0);
    // z->printShapeInfo("The first of");
    // ASSERT_TRUE(exp.isSameShape(z));
    // ASSERT_TRUE(exp.equalsTo(z));
    for (int e = 0; e < 10; e++) {
        auto row = list.read(e);
        ASSERT_TRUE(row->equalsTo(tads->at(e)));
        //list.write(e, row);
    }

    delete result;
    delete tads;
}
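
// unstack_list splits x along dimension 0 into 10 rows and writes them into the supplied
// NDArrayList; the loop afterwards checks each stored row against the corresponding TAD
// (tensor along dimension) of the original array.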

//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
////    NDArrayList list(0, true);
//    auto x = NDArrayFactory::create<double>('c', {10, 100});
//    auto tads = x.allTensorsAlongDimension({1});
//    for (int e = 0; e < 10; e++) {
//        auto row = NDArrayFactory::create_<double>('c', {100});
//        row->assign((double) e);
//        //list.write(e, row);
//        tads->at(e)->assign(row);
//        delete row;
//    }
//
//    nd4j::ops::unstack_list op;
//
//    auto result = op.execute(nullptr, {&x}, {}, {0});
//
//    ASSERT_EQ(ND4J_STATUS_OK, result->status());
//    ASSERT_EQ(result->size(), 10);
//
//    // auto z = result->at(0);
////    z->printShapeInfo("The first of");
////    ASSERT_TRUE(exp.isSameShape(z));
////    ASSERT_TRUE(exp.equalsTo(z));
//    for (int e = 0; e < 10; e++) {
//        auto row = result->at(e);
//        ASSERT_TRUE(row->equalsTo(tads->at(e)));
//        //list.write(e, row);
//    }
//
//    delete result;
//    delete tads;
//}

TEST_F(ListOperationsTests, BasicTest_Read_1) {
    NDArrayList list(10);