diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp index f932a1274..8d30185b1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp @@ -31,6 +31,11 @@ namespace ops { auto start = INPUT_VARIABLE(0); auto finish = INPUT_VARIABLE(1); auto numOfElements = INPUT_VARIABLE(2); + + if (numOfElements->e(0) == 1) { + output->assign(start); + return Status::OK(); + } output->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp index 04b5b48d6..7faf82b08 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp @@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) { const int numIArgs = block.getIArguments()->size(); Nd4jLong steps = 0; - nd4j::DataType dataType = nd4j::DataType::INHERIT; + nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT; if (numInArrs > 0) { auto isR = INPUT_VARIABLE(0)->isR(); @@ -159,7 +159,9 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - dataType = INPUT_VARIABLE(0)->dataType(); + + if (!block.numD()) + dataType = INPUT_VARIABLE(0)->dataType(); if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; @@ -187,7 +189,9 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - dataType = INPUT_VARIABLE(0)->dataType(); + + if (!block.numD()) + dataType = INPUT_VARIABLE(0)->dataType(); if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; @@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - if (limit > DataTypeUtils::max()) - dataType = nd4j::DataType::INT64; - else - dataType = nd4j::DataType::INT32; + if (!block.numD()) { + if (limit > DataTypeUtils::max()) + dataType = nd4j::DataType::INT64; + else + dataType = nd4j::DataType::INT32; + } steps = (limit - start) / delta; @@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - if (Environment::getInstance()->precisionBoostAllowed()) - dataType = nd4j::DataType::DOUBLE; - else - dataType = Environment::getInstance()->defaultFloatDataType(); + + if (!block.numD()) { + if (Environment::getInstance()->precisionBoostAllowed()) + dataType = nd4j::DataType::DOUBLE; + else + dataType = Environment::getInstance()->defaultFloatDataType(); + } if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 6f26c1095..7c4138d36 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -830,8 +830,12 @@ namespace nd4j { template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - std::vector realArgs(tArgs); - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector());; + return execute(inputs, outputs, tArgs, std::vector(), std::vector(), std::vector()); + } + + template <> + Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list dArgs) { + return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), dArgs); } template <> @@ -840,13 +844,12 @@ namespace nd4j { for (auto v:tArgs) realArgs.emplace_back(v); - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector());; + return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector()); } template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - std::vector realArgs(iArgs); - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector());; + return execute(inputs, outputs, std::vector(), iArgs, std::vector(), std::vector()); } template <> @@ -855,13 +858,12 @@ namespace nd4j { for (auto v:iArgs) realArgs.emplace_back(v); - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector());; + return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector()); } template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list bArgs) { - std::vector realArgs(bArgs); - return execute(inputs, outputs, std::vector(), std::vector(), realArgs, std::vector());; + return execute(inputs, outputs, std::vector(), std::vector(), bArgs, std::vector()); } Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { @@ -903,13 +905,12 @@ namespace nd4j { for (auto v:iArgs) realArgs.emplace_back(v); - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector());; + return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - std::vector realArgs(iArgs); - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector());; + return evaluate(inputs, std::vector(), iArgs, std::vector(), std::vector()); } template <> @@ -918,19 +919,22 @@ namespace nd4j { for (auto v:tArgs) realArgs.emplace_back(v); - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector());; + return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - std::vector realArgs(tArgs); - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector());; + return evaluate(inputs, tArgs, std::vector(), std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - std::vector realArgs(bArgs); - return evaluate(inputs, std::vector(), std::vector(), realArgs, std::vector());; + return evaluate(inputs, std::vector(), std::vector(), bArgs, std::vector()); + } + + template <> + nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), std::vector(), bArgs); } nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index e39589270..04816b2b2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -438,6 +438,26 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) { } +TEST_F(DeclarableOpsTests3, Test_Range_10) { + auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + + nd4j::ops::range op; + auto result = op.evaluate({&start, &stop, &step}, {nd4j::DataType::DOUBLE}); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + + TEST_F(DeclarableOpsTests3, Test_Range_4) { auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java index d211fbe25..ade01281c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java @@ -51,6 +51,7 @@ public class Range extends DynamicCustomOp { public Range(SameDiff sd, double from, double to, double step, DataType dataType){ super(null, sd, new SDVariable[0]); addTArgument(from, to, step); + addDArgument(dataType); this.from = from; this.to = to; this.delta = step; @@ -63,11 +64,13 @@ public class Range extends DynamicCustomOp { this.to = to; this.delta = step; this.dataType = dataType; + addDArgument(dataType); } public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){ super(null, sd, new SDVariable[]{from, to, step}); this.dataType = dataType; + addDArgument(dataType); } @@ -99,6 +102,7 @@ public class Range extends DynamicCustomOp { if(attributesForNode.containsKey("Tidx")){ dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType()); } + addDArgument(dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index d32aff5b1..8e638f373 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -1944,6 +1944,9 @@ public class Nd4j { if(lower == upper && num == 1) { return Nd4j.scalar(dtype, lower); } + if (num == 1) { + return Nd4j.scalar(dtype, lower); + } if (dtype.isIntType()) { return linspaceWithCustomOp(lower, upper, (int)num, dtype); } else if (dtype.isFPType()) { @@ -1964,6 +1967,9 @@ public class Nd4j { */ public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) { Preconditions.checkState(dataType.isFPType()); + if (num == 1) + return Nd4j.scalar(dataType, lower); + return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType)); } @@ -1977,10 +1983,15 @@ public class Nd4j { */ public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) { Preconditions.checkState(dataType.isFPType()); + if (num == 1) + return Nd4j.scalar(dataType, lower); + return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType)); } private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) { + if (num == 1) + return Nd4j.scalar(dataType, lower); INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order()); @@ -1994,6 +2005,8 @@ public class Nd4j { } private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) { + if (num == 1) + return Nd4j.scalar(dataType, lower); INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index e9d2979c6..49ff345e7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -1683,4 +1683,12 @@ public class CustomOpsTests extends BaseNd4jTest { val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0]; assertEquals(e, z); } + + @Test + public void testLinSpaceEdge_1() { + val x = Nd4j.linspace(1,10,1, DataType.FLOAT); + val e = Nd4j.scalar(1.0f); + + assertEquals(e, x); + } }