Range op data type (#204)
* - range op now accepts dargs - dargs now can be in signature Signed-off-by: raver119 <raver119@gmail.com> * range dtype java side Signed-off-by: raver119 <raver119@gmail.com> * linspace fix Signed-off-by: raver119 <raver119@gmail.com> * lin_space fix for scalar outputs Signed-off-by: raver119 <raver119@gmail.com>master
parent
d39ca6d488
commit
1ab86d1306
|
@ -31,6 +31,11 @@ namespace ops {
|
|||
auto start = INPUT_VARIABLE(0);
|
||||
auto finish = INPUT_VARIABLE(1);
|
||||
auto numOfElements = INPUT_VARIABLE(2);
|
||||
|
||||
if (numOfElements->e<Nd4jLong>(0) == 1) {
|
||||
output->assign(start);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
|
||||
return Status::OK();
|
||||
|
|
|
@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) {
|
|||
const int numIArgs = block.getIArguments()->size();
|
||||
|
||||
Nd4jLong steps = 0;
|
||||
nd4j::DataType dataType = nd4j::DataType::INHERIT;
|
||||
nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT;
|
||||
|
||||
if (numInArrs > 0) {
|
||||
auto isR = INPUT_VARIABLE(0)->isR();
|
||||
|
@ -159,7 +159,9 @@ DECLARE_SHAPE_FN(range) {
|
|||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||
dataType = INPUT_VARIABLE(0)->dataType();
|
||||
|
||||
if (!block.numD())
|
||||
dataType = INPUT_VARIABLE(0)->dataType();
|
||||
|
||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||
++steps;
|
||||
|
@ -187,7 +189,9 @@ DECLARE_SHAPE_FN(range) {
|
|||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||
dataType = INPUT_VARIABLE(0)->dataType();
|
||||
|
||||
if (!block.numD())
|
||||
dataType = INPUT_VARIABLE(0)->dataType();
|
||||
|
||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||
++steps;
|
||||
|
@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) {
|
|||
|
||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
if (limit > DataTypeUtils::max<int>())
|
||||
dataType = nd4j::DataType::INT64;
|
||||
else
|
||||
dataType = nd4j::DataType::INT32;
|
||||
if (!block.numD()) {
|
||||
if (limit > DataTypeUtils::max<int>())
|
||||
dataType = nd4j::DataType::INT64;
|
||||
else
|
||||
dataType = nd4j::DataType::INT32;
|
||||
}
|
||||
|
||||
steps = (limit - start) / delta;
|
||||
|
||||
|
@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) {
|
|||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||
if (Environment::getInstance()->precisionBoostAllowed())
|
||||
dataType = nd4j::DataType::DOUBLE;
|
||||
else
|
||||
dataType = Environment::getInstance()->defaultFloatDataType();
|
||||
|
||||
if (!block.numD()) {
|
||||
if (Environment::getInstance()->precisionBoostAllowed())
|
||||
dataType = nd4j::DataType::DOUBLE;
|
||||
else
|
||||
dataType = Environment::getInstance()->defaultFloatDataType();
|
||||
}
|
||||
|
||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||
++steps;
|
||||
|
|
|
@ -830,8 +830,12 @@ namespace nd4j {
|
|||
|
||||
template <>
|
||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
|
||||
std::vector<double> realArgs(tArgs);
|
||||
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return execute(inputs, outputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<nd4j::DataType> dArgs) {
|
||||
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), dArgs);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -840,13 +844,12 @@ namespace nd4j {
|
|||
for (auto v:tArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
||||
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||
std::vector<Nd4jLong> realArgs(iArgs);
|
||||
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return execute(inputs, outputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -855,13 +858,12 @@ namespace nd4j {
|
|||
for (auto v:iArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
||||
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
|
||||
std::vector<bool> realArgs(bArgs);
|
||||
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
|
||||
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
||||
|
@ -903,13 +905,12 @@ namespace nd4j {
|
|||
for (auto v:iArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
||||
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||
std::vector<Nd4jLong> realArgs(iArgs);
|
||||
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -918,19 +919,22 @@ namespace nd4j {
|
|||
for (auto v:tArgs)
|
||||
realArgs.emplace_back(v);
|
||||
|
||||
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
||||
std::vector<double> realArgs(tArgs);
|
||||
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
||||
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
||||
std::vector<bool> realArgs(bArgs);
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
||||
}
|
||||
|
||||
template <>
|
||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<nd4j::DataType> bArgs) {
|
||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
||||
}
|
||||
|
||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
||||
|
|
|
@ -438,6 +438,26 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) {
|
|||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Range_10) {
|
||||
auto start= NDArrayFactory::create<float>('c', {1, 1}, {0.f});
|
||||
auto stop= NDArrayFactory::create<float>('c', {1, 1}, {2.f});
|
||||
auto step= NDArrayFactory::create<float>('c', {1, 1}, {1.f});
|
||||
auto exp= NDArrayFactory::create<double>('c', {2}, {0.f, 1.f});
|
||||
|
||||
nd4j::ops::range op;
|
||||
auto result = op.evaluate({&start, &stop, &step}, {nd4j::DataType::DOUBLE});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Range_4) {
|
||||
auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ public class Range extends DynamicCustomOp {
|
|||
public Range(SameDiff sd, double from, double to, double step, DataType dataType){
|
||||
super(null, sd, new SDVariable[0]);
|
||||
addTArgument(from, to, step);
|
||||
addDArgument(dataType);
|
||||
this.from = from;
|
||||
this.to = to;
|
||||
this.delta = step;
|
||||
|
@ -63,11 +64,13 @@ public class Range extends DynamicCustomOp {
|
|||
this.to = to;
|
||||
this.delta = step;
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){
|
||||
super(null, sd, new SDVariable[]{from, to, step});
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
|
||||
|
@ -99,6 +102,7 @@ public class Range extends DynamicCustomOp {
|
|||
if(attributesForNode.containsKey("Tidx")){
|
||||
dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
|
||||
}
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1944,6 +1944,9 @@ public class Nd4j {
|
|||
if(lower == upper && num == 1) {
|
||||
return Nd4j.scalar(dtype, lower);
|
||||
}
|
||||
if (num == 1) {
|
||||
return Nd4j.scalar(dtype, lower);
|
||||
}
|
||||
if (dtype.isIntType()) {
|
||||
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
|
||||
} else if (dtype.isFPType()) {
|
||||
|
@ -1964,6 +1967,9 @@ public class Nd4j {
|
|||
*/
|
||||
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
||||
Preconditions.checkState(dataType.isFPType());
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
||||
}
|
||||
|
||||
|
@ -1977,10 +1983,15 @@ public class Nd4j {
|
|||
*/
|
||||
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
||||
Preconditions.checkState(dataType.isFPType());
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
||||
}
|
||||
|
||||
private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) {
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
||||
|
||||
|
@ -1994,6 +2005,8 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) {
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
||||
|
||||
|
|
|
@ -1683,4 +1683,12 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLinSpaceEdge_1() {
|
||||
val x = Nd4j.linspace(1,10,1, DataType.FLOAT);
|
||||
val e = Nd4j.scalar(1.0f);
|
||||
|
||||
assertEquals(e, x);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue