Range op data type (#204)

* - range op now accepts dargs
- dargs now can be in signature

Signed-off-by: raver119 <raver119@gmail.com>

* range dtype java side

Signed-off-by: raver119 <raver119@gmail.com>

* linspace fix

Signed-off-by: raver119 <raver119@gmail.com>

* lin_space fix for scalar outputs

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-31 10:45:41 +03:00 committed by GitHub
parent d39ca6d488
commit 1ab86d1306
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 27 deletions

View File

@ -31,6 +31,11 @@ namespace ops {
auto start = INPUT_VARIABLE(0);
auto finish = INPUT_VARIABLE(1);
auto numOfElements = INPUT_VARIABLE(2);
if (numOfElements->e<Nd4jLong>(0) == 1) {
output->assign(start);
return Status::OK();
}
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
return Status::OK();

View File

@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) {
const int numIArgs = block.getIArguments()->size();
Nd4jLong steps = 0;
nd4j::DataType dataType = nd4j::DataType::INHERIT;
nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT;
if (numInArrs > 0) {
auto isR = INPUT_VARIABLE(0)->isR();
@ -159,7 +159,9 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
dataType = INPUT_VARIABLE(0)->dataType();
if (!block.numD())
dataType = INPUT_VARIABLE(0)->dataType();
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
++steps;
@ -187,7 +189,9 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
dataType = INPUT_VARIABLE(0)->dataType();
if (!block.numD())
dataType = INPUT_VARIABLE(0)->dataType();
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
++steps;
@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
if (limit > DataTypeUtils::max<int>())
dataType = nd4j::DataType::INT64;
else
dataType = nd4j::DataType::INT32;
if (!block.numD()) {
if (limit > DataTypeUtils::max<int>())
dataType = nd4j::DataType::INT64;
else
dataType = nd4j::DataType::INT32;
}
steps = (limit - start) / delta;
@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) {
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
steps = static_cast<Nd4jLong >((limit - start) / delta);
if (Environment::getInstance()->precisionBoostAllowed())
dataType = nd4j::DataType::DOUBLE;
else
dataType = Environment::getInstance()->defaultFloatDataType();
if (!block.numD()) {
if (Environment::getInstance()->precisionBoostAllowed())
dataType = nd4j::DataType::DOUBLE;
else
dataType = Environment::getInstance()->defaultFloatDataType();
}
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
++steps;

View File

@ -830,8 +830,12 @@ namespace nd4j {
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
std::vector<double> realArgs(tArgs);
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
return execute(inputs, outputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<nd4j::DataType> dArgs) {
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), dArgs);
}
template <>
@ -840,13 +844,12 @@ namespace nd4j {
for (auto v:tArgs)
realArgs.emplace_back(v);
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
std::vector<Nd4jLong> realArgs(iArgs);
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
return execute(inputs, outputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
@ -855,13 +858,12 @@ namespace nd4j {
for (auto v:iArgs)
realArgs.emplace_back(v);
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
std::vector<bool> realArgs(bArgs);
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
}
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
@ -903,13 +905,12 @@ namespace nd4j {
for (auto v:iArgs)
realArgs.emplace_back(v);
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
std::vector<Nd4jLong> realArgs(iArgs);
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
@ -918,19 +919,22 @@ namespace nd4j {
for (auto v:tArgs)
realArgs.emplace_back(v);
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
std::vector<double> realArgs(tArgs);
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
std::vector<bool> realArgs(bArgs);
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
}
template <>
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<nd4j::DataType> bArgs) {
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
}
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {

View File

@ -438,6 +438,26 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) {
}
TEST_F(DeclarableOpsTests3, Test_Range_10) {
auto start= NDArrayFactory::create<float>('c', {1, 1}, {0.f});
auto stop= NDArrayFactory::create<float>('c', {1, 1}, {2.f});
auto step= NDArrayFactory::create<float>('c', {1, 1}, {1.f});
auto exp= NDArrayFactory::create<double>('c', {2}, {0.f, 1.f});
nd4j::ops::range op;
auto result = op.evaluate({&start, &stop, &step}, {nd4j::DataType::DOUBLE});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Range_4) {
auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});

View File

@ -51,6 +51,7 @@ public class Range extends DynamicCustomOp {
public Range(SameDiff sd, double from, double to, double step, DataType dataType){
super(null, sd, new SDVariable[0]);
addTArgument(from, to, step);
addDArgument(dataType);
this.from = from;
this.to = to;
this.delta = step;
@ -63,11 +64,13 @@ public class Range extends DynamicCustomOp {
this.to = to;
this.delta = step;
this.dataType = dataType;
addDArgument(dataType);
}
public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){
super(null, sd, new SDVariable[]{from, to, step});
this.dataType = dataType;
addDArgument(dataType);
}
@ -99,6 +102,7 @@ public class Range extends DynamicCustomOp {
if(attributesForNode.containsKey("Tidx")){
dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
}
addDArgument(dataType);
}
@Override

View File

@ -1944,6 +1944,9 @@ public class Nd4j {
if(lower == upper && num == 1) {
return Nd4j.scalar(dtype, lower);
}
if (num == 1) {
return Nd4j.scalar(dtype, lower);
}
if (dtype.isIntType()) {
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
} else if (dtype.isFPType()) {
@ -1964,6 +1967,9 @@ public class Nd4j {
*/
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
Preconditions.checkState(dataType.isFPType());
if (num == 1)
return Nd4j.scalar(dataType, lower);
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
}
@ -1977,10 +1983,15 @@ public class Nd4j {
*/
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
Preconditions.checkState(dataType.isFPType());
if (num == 1)
return Nd4j.scalar(dataType, lower);
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
}
private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) {
if (num == 1)
return Nd4j.scalar(dataType, lower);
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
@ -1994,6 +2005,8 @@ public class Nd4j {
}
private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) {
if (num == 1)
return Nd4j.scalar(dataType, lower);
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());

View File

@ -1683,4 +1683,12 @@ public class CustomOpsTests extends BaseNd4jTest {
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
assertEquals(e, z);
}
@Test
public void testLinSpaceEdge_1() {
val x = Nd4j.linspace(1,10,1, DataType.FLOAT);
val e = Nd4j.scalar(1.0f);
assertEquals(e, x);
}
}