[WIP] few tweaks (#206)

* scatter empty check

Signed-off-by: raver119 <raver119@gmail.com>

* scatter empty test

Signed-off-by: raver119 <raver119@gmail.com>

* one more test

Signed-off-by: raver119 <raver119@gmail.com>

* two tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* dup tweak

Signed-off-by: raver119 <raver119@gmail.com>

* - put empty checking of indices array immediately prior  helper run

Signed-off-by: Yurii <yurii@skymind.io>

* minor tests fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor tests fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-30 16:32:01 +03:00 committed by GitHub
parent f414239ed5
commit 70a9ae5068
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 134 additions and 82 deletions

View File

@ -35,6 +35,9 @@ OP_IMPL(scatter_add, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -68,10 +71,8 @@ OP_IMPL(scatter_add, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -35,6 +35,9 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -67,11 +70,10 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); // ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
// ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }
DECLARE_SYN(ScatterDiv, scatter_div); DECLARE_SYN(ScatterDiv, scatter_div);

View File

@ -35,6 +35,9 @@ OP_IMPL(scatter_max, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -67,10 +70,8 @@ OP_IMPL(scatter_max, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -35,6 +35,9 @@ OP_IMPL(scatter_min, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -67,10 +70,8 @@ OP_IMPL(scatter_min, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -40,6 +40,10 @@ namespace nd4j {
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
const int updRank = updates->rankOf(); const int updRank = updates->rankOf();
if (!block.isInplace())
output->assign(input);
REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !");
if(inRank == 1) { if(inRank == 1) {
@ -66,10 +70,8 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -34,6 +34,9 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -67,11 +70,9 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
// ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -33,6 +33,9 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(input);
const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
@ -65,11 +68,9 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!block.isInplace()) if (!indices->isEmpty())
output->assign(input); // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
// ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
return Status::OK(); return Status::OK();
} }

View File

@ -207,6 +207,46 @@ TEST_F(EmptyTests, Test_dup_1) {
delete dup; delete dup;
} }
TEST_F(EmptyTests, test_empty_scatter_1) {
auto x = NDArrayFactory::create<float>('c', {5});
auto indices = NDArrayFactory::create<int>('c', {0});
auto updates = NDArrayFactory::create<float>('c', {0});
x.linspace(1.0f);
nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(x, *z);
delete result;
}
TEST_F(EmptyTests, test_empty_scatter_2) {
auto x = NDArrayFactory::create<float>('c', {5});
auto z = NDArrayFactory::create<float>('c', {5});
auto indices = NDArrayFactory::create<int>('c', {0});
auto updates = NDArrayFactory::create<float>('c', {0});
x.linspace(1.0f);
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo());
ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
bool args[] = {true};
ctx.setBArguments(args, 1);
nd4j::ops::scatter_upd op;
auto result = op.execute(&ctx);
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(x, z);
}
TEST_F(EmptyTests, test_shaped_empty_1) { TEST_F(EmptyTests, test_shaped_empty_1) {
auto empty = NDArrayFactory::create<float>('c', {2, 0, 3}); auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
std::vector<Nd4jLong> shape = {2, 0, 3}; std::vector<Nd4jLong> shape = {2, 0, 3};

View File

@ -1732,6 +1732,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if(isEmpty()) if(isEmpty())
return this; return this;
Nd4j.getCompressor().autoDecompress(this);
// fixme: eventually it would be nice to have this in native code // fixme: eventually it would be nice to have this in native code
if (isS()) { if (isS()) {
val list = new ArrayList<String>(); val list = new ArrayList<String>();
@ -1741,8 +1743,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.create(list, this.shape(), this.ordering()); return Nd4j.create(list, this.shape(), this.ordering());
} }
Nd4j.getCompressor().autoDecompress(this); val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order);
return Shape.toOffsetZeroCopy(this, order); z.assign(this);
return z;
} }
/** /**