[WIP] few tweaks (#206)

* scatter empty check
* scatter empty test
* one more test
* two tweaks
* dup tweak
* put empty check of indices array immediately prior to the helper run
* minor tests fix
* minor tests fix

Signed-off-by: raver119 <raver119@gmail.com>
Signed-off-by: Yurii <yurii@skymind.io>
parent: f414239ed5
commit: 70a9ae5068
@@ -35,6 +35,9 @@ OP_IMPL(scatter_add, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
@@ -68,10 +71,8 @@ OP_IMPL(scatter_add, 3, 1, true) {
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
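Taken together, the two scatter_add hunks move the not-in-place copy of the input up front and gate the helper call on non-empty indices. A condensed sketch of the resulting control flow (shape and rank validation from the real op body omitted for brevity):

OP_IMPL(scatter_add, 3, 1, true) {
    auto input   = INPUT_VARIABLE(0);
    auto indices = INPUT_VARIABLE(1);
    auto updates = INPUT_VARIABLE(2);
    auto output  = OUTPUT_VARIABLE(0);

    // when not running in-place, seed the output with the input first
    if (!block.isInplace())
        output->assign(input);

    const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);

    // with empty indices there is nothing to scatter, so the op reduces to the
    // copy above (or a no-op in-place) and the helper is skipped entirely
    if (!indices->isEmpty())
        helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);

    return Status::OK();
}

The same pattern is applied to the remaining scatter ops below.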
@@ -16,7 +16,7 @@
 //
 // @author Created by raver119 on 24.11.17.
 // @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <op_boilerplate.h>
@@ -28,21 +28,24 @@
 namespace nd4j {
 namespace ops {
 OP_IMPL(scatter_div, 3, 1, true) {
 
     auto input = INPUT_VARIABLE(0);
     auto indices = INPUT_VARIABLE(1);
     auto updates = INPUT_VARIABLE(2);
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
 
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -50,28 +53,27 @@ namespace nd4j {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-    // ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
-    helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        // ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
+        helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
 DECLARE_SYN(ScatterDiv, scatter_div);
@@ -35,14 +35,17 @@ OP_IMPL(scatter_max, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
 
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MAX OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -50,28 +53,26 @@ OP_IMPL(scatter_max, 3, 1, true) {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MAX OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
 DECLARE_SYN(ScatterMax, scatter_max);
@@ -35,14 +35,17 @@ OP_IMPL(scatter_min, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
 
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MIN OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -50,27 +53,25 @@ OP_IMPL(scatter_min, 3, 1, true) {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MIN OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
@@ -39,9 +39,13 @@ namespace nd4j {
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
    	const int updRank = updates->rankOf();
 
+    if (!block.isInplace())
+        output->assign(input);
+
+
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MUL OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -49,27 +53,25 @@ namespace nd4j {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MUL OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
@@ -34,14 +34,17 @@ namespace nd4j {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
 
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_SUB OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -49,29 +52,27 @@ namespace nd4j {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_SUB OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
-    helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
+        helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
@@ -33,14 +33,17 @@ namespace nd4j {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
 
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !");
 
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_UPD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -48,28 +51,26 @@ namespace nd4j {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
 
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_UPD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
 
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
-    helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
+        helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
@@ -207,6 +207,46 @@ TEST_F(EmptyTests, Test_dup_1) {
     delete dup;
 }
 
+TEST_F(EmptyTests, test_empty_scatter_1) {
+    auto x = NDArrayFactory::create<float>('c', {5});
+    auto indices = NDArrayFactory::create<int>('c', {0});
+    auto updates = NDArrayFactory::create<float>('c', {0});
+
+    x.linspace(1.0f);
+
+    nd4j::ops::scatter_upd op;
+    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    ASSERT_EQ(x, *z);
+
+    delete result;
+}
+
+TEST_F(EmptyTests, test_empty_scatter_2) {
+    auto x = NDArrayFactory::create<float>('c', {5});
+    auto z = NDArrayFactory::create<float>('c', {5});
+    auto indices = NDArrayFactory::create<int>('c', {0});
+    auto updates = NDArrayFactory::create<float>('c', {0});
+
+    x.linspace(1.0f);
+
+    Context ctx(1);
+    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
+    ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo());
+    ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo());
+    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
+    bool args[] = {true};
+    ctx.setBArguments(args, 1);
+
+    nd4j::ops::scatter_upd op;
+    auto result = op.execute(&ctx);
+    ASSERT_EQ(Status::OK(), result);
+
+    ASSERT_EQ(x, z);
+}
+
 TEST_F(EmptyTests, test_shaped_empty_1) {
     auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
     std::vector<Nd4jLong> shape = {2, 0, 3};
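The two new tests pin down the intended behaviour: with empty indices and updates, scatter_upd must succeed and leave the output equal to the input, both through the regular execute path and through a prebuilt Context with the boolean lock argument set. A hypothetical companion test (not part of this commit) could apply the same check to scatter_add, reusing only the calls already shown above:

TEST_F(EmptyTests, test_empty_scatter_add_sketch) {
    auto x = NDArrayFactory::create<float>('c', {5});
    auto indices = NDArrayFactory::create<int>('c', {0});
    auto updates = NDArrayFactory::create<float>('c', {0});

    x.linspace(1.0f);

    nd4j::ops::scatter_add op;
    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
    ASSERT_EQ(Status::OK(), result->status());

    // with nothing to scatter, the output must simply equal the input
    ASSERT_EQ(x, *result->at(0));

    delete result;
}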
@@ -1732,6 +1732,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         if(isEmpty())
             return this;
 
+        Nd4j.getCompressor().autoDecompress(this);
+
         // fixme: eventually it would be nice to have this in native code
         if (isS()) {
             val list = new ArrayList<String>();
@@ -1741,8 +1743,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
             return Nd4j.create(list, this.shape(), this.ordering());
         }
 
-        Nd4j.getCompressor().autoDecompress(this);
-        return Shape.toOffsetZeroCopy(this, order);
+        val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order);
+        z.assign(this);
+        return z;
     }
 
     /**
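For readers less familiar with the ND4J side, the hunk above replaces the Shape.toOffsetZeroCopy path in what appears to be dup(order) with an explicit createUninitialized-plus-assign copy, performed after the decompression and empty/string special cases. A standalone sketch of that copy pattern, assuming a standard ND4J setup (the class and helper names here are illustrative, not part of the commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DupSketch {
    // Mirrors the new copy path: allocate an uninitialized array with the same
    // data type and shape in the requested order, then assign the source into
    // it so the copy materializes with that ordering.
    static INDArray copyWithOrder(INDArray src, char order) {
        INDArray z = Nd4j.createUninitialized(src.dataType(), src.shape(), order);
        z.assign(src);
        return z;
    }

    public static void main(String[] args) {
        INDArray x = Nd4j.ones(2, 3);
        System.out.println(copyWithOrder(x, 'f'));
    }
}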