Range op data type (#204)
* - range op now accepts dargs - dargs now can be in signature Signed-off-by: raver119 <raver119@gmail.com> * range dtype java side Signed-off-by: raver119 <raver119@gmail.com> * linspace fix Signed-off-by: raver119 <raver119@gmail.com> * lin_space fix for scalar outputs Signed-off-by: raver119 <raver119@gmail.com>master
parent
d39ca6d488
commit
1ab86d1306
|
@ -32,6 +32,11 @@ namespace ops {
|
||||||
auto finish = INPUT_VARIABLE(1);
|
auto finish = INPUT_VARIABLE(1);
|
||||||
auto numOfElements = INPUT_VARIABLE(2);
|
auto numOfElements = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
if (numOfElements->e<Nd4jLong>(0) == 1) {
|
||||||
|
output->assign(start);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
|
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) {
|
||||||
const int numIArgs = block.getIArguments()->size();
|
const int numIArgs = block.getIArguments()->size();
|
||||||
|
|
||||||
Nd4jLong steps = 0;
|
Nd4jLong steps = 0;
|
||||||
nd4j::DataType dataType = nd4j::DataType::INHERIT;
|
nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT;
|
||||||
|
|
||||||
if (numInArrs > 0) {
|
if (numInArrs > 0) {
|
||||||
auto isR = INPUT_VARIABLE(0)->isR();
|
auto isR = INPUT_VARIABLE(0)->isR();
|
||||||
|
@ -159,6 +159,8 @@ DECLARE_SHAPE_FN(range) {
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||||
|
|
||||||
|
if (!block.numD())
|
||||||
dataType = INPUT_VARIABLE(0)->dataType();
|
dataType = INPUT_VARIABLE(0)->dataType();
|
||||||
|
|
||||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||||
|
@ -187,6 +189,8 @@ DECLARE_SHAPE_FN(range) {
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||||
|
|
||||||
|
if (!block.numD())
|
||||||
dataType = INPUT_VARIABLE(0)->dataType();
|
dataType = INPUT_VARIABLE(0)->dataType();
|
||||||
|
|
||||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||||
|
@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) {
|
||||||
|
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
|
if (!block.numD()) {
|
||||||
if (limit > DataTypeUtils::max<int>())
|
if (limit > DataTypeUtils::max<int>())
|
||||||
dataType = nd4j::DataType::INT64;
|
dataType = nd4j::DataType::INT64;
|
||||||
else
|
else
|
||||||
dataType = nd4j::DataType::INT32;
|
dataType = nd4j::DataType::INT32;
|
||||||
|
}
|
||||||
|
|
||||||
steps = (limit - start) / delta;
|
steps = (limit - start) / delta;
|
||||||
|
|
||||||
|
@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) {
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
steps = static_cast<Nd4jLong >((limit - start) / delta);
|
||||||
|
|
||||||
|
if (!block.numD()) {
|
||||||
if (Environment::getInstance()->precisionBoostAllowed())
|
if (Environment::getInstance()->precisionBoostAllowed())
|
||||||
dataType = nd4j::DataType::DOUBLE;
|
dataType = nd4j::DataType::DOUBLE;
|
||||||
else
|
else
|
||||||
dataType = Environment::getInstance()->defaultFloatDataType();
|
dataType = Environment::getInstance()->defaultFloatDataType();
|
||||||
|
}
|
||||||
|
|
||||||
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
if(math::nd4j_abs<double>(start + steps * delta) < math::nd4j_abs<double >(limit))
|
||||||
++steps;
|
++steps;
|
||||||
|
|
|
@ -830,8 +830,12 @@ namespace nd4j {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<double> tArgs) {
|
||||||
std::vector<double> realArgs(tArgs);
|
return execute(inputs, outputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<nd4j::DataType> dArgs) {
|
||||||
|
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), dArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -840,13 +844,12 @@ namespace nd4j {
|
||||||
for (auto v:tArgs)
|
for (auto v:tArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
|
||||||
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||||
std::vector<Nd4jLong> realArgs(iArgs);
|
return execute(inputs, outputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -855,13 +858,12 @@ namespace nd4j {
|
||||||
for (auto v:iArgs)
|
for (auto v:iArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
|
||||||
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, std::initializer_list<bool> bArgs) {
|
||||||
std::vector<bool> realArgs(bArgs);
|
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
||||||
return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
Nd4jStatus DeclarableOp::execute(const std::vector<NDArray *> &inputs, const std::vector<NDArray *> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
||||||
|
@ -903,13 +905,12 @@ namespace nd4j {
|
||||||
for (auto v:iArgs)
|
for (auto v:iArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
|
||||||
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<Nd4jLong> iArgs) {
|
||||||
std::vector<Nd4jLong> realArgs(iArgs);
|
return evaluate(inputs, std::vector<double>(), iArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -918,19 +919,22 @@ namespace nd4j {
|
||||||
for (auto v:tArgs)
|
for (auto v:tArgs)
|
||||||
realArgs.emplace_back(v);
|
realArgs.emplace_back(v);
|
||||||
|
|
||||||
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<double> tArgs) {
|
||||||
std::vector<double> realArgs(tArgs);
|
return evaluate(inputs, tArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
|
||||||
return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<bool> bArgs) {
|
||||||
std::vector<bool> realArgs(bArgs);
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), bArgs, std::vector<nd4j::DataType>());
|
||||||
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());;
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, std::initializer_list<nd4j::DataType> bArgs) {
|
||||||
|
return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), bArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray *> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
|
||||||
|
|
|
@ -438,6 +438,26 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests3, Test_Range_10) {
|
||||||
|
auto start= NDArrayFactory::create<float>('c', {1, 1}, {0.f});
|
||||||
|
auto stop= NDArrayFactory::create<float>('c', {1, 1}, {2.f});
|
||||||
|
auto step= NDArrayFactory::create<float>('c', {1, 1}, {1.f});
|
||||||
|
auto exp= NDArrayFactory::create<double>('c', {2}, {0.f, 1.f});
|
||||||
|
|
||||||
|
nd4j::ops::range op;
|
||||||
|
auto result = op.evaluate({&start, &stop, &step}, {nd4j::DataType::DOUBLE});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests3, Test_Range_4) {
|
TEST_F(DeclarableOpsTests3, Test_Range_4) {
|
||||||
auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});
|
auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,7 @@ public class Range extends DynamicCustomOp {
|
||||||
public Range(SameDiff sd, double from, double to, double step, DataType dataType){
|
public Range(SameDiff sd, double from, double to, double step, DataType dataType){
|
||||||
super(null, sd, new SDVariable[0]);
|
super(null, sd, new SDVariable[0]);
|
||||||
addTArgument(from, to, step);
|
addTArgument(from, to, step);
|
||||||
|
addDArgument(dataType);
|
||||||
this.from = from;
|
this.from = from;
|
||||||
this.to = to;
|
this.to = to;
|
||||||
this.delta = step;
|
this.delta = step;
|
||||||
|
@ -63,11 +64,13 @@ public class Range extends DynamicCustomOp {
|
||||||
this.to = to;
|
this.to = to;
|
||||||
this.delta = step;
|
this.delta = step;
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){
|
public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){
|
||||||
super(null, sd, new SDVariable[]{from, to, step});
|
super(null, sd, new SDVariable[]{from, to, step});
|
||||||
this.dataType = dataType;
|
this.dataType = dataType;
|
||||||
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,6 +102,7 @@ public class Range extends DynamicCustomOp {
|
||||||
if(attributesForNode.containsKey("Tidx")){
|
if(attributesForNode.containsKey("Tidx")){
|
||||||
dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
|
dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
|
||||||
}
|
}
|
||||||
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1944,6 +1944,9 @@ public class Nd4j {
|
||||||
if(lower == upper && num == 1) {
|
if(lower == upper && num == 1) {
|
||||||
return Nd4j.scalar(dtype, lower);
|
return Nd4j.scalar(dtype, lower);
|
||||||
}
|
}
|
||||||
|
if (num == 1) {
|
||||||
|
return Nd4j.scalar(dtype, lower);
|
||||||
|
}
|
||||||
if (dtype.isIntType()) {
|
if (dtype.isIntType()) {
|
||||||
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
|
return linspaceWithCustomOp(lower, upper, (int)num, dtype);
|
||||||
} else if (dtype.isFPType()) {
|
} else if (dtype.isFPType()) {
|
||||||
|
@ -1964,6 +1967,9 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
||||||
Preconditions.checkState(dataType.isFPType());
|
Preconditions.checkState(dataType.isFPType());
|
||||||
|
if (num == 1)
|
||||||
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1977,10 +1983,15 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
||||||
Preconditions.checkState(dataType.isFPType());
|
Preconditions.checkState(dataType.isFPType());
|
||||||
|
if (num == 1)
|
||||||
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) {
|
private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) {
|
||||||
|
if (num == 1)
|
||||||
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
||||||
|
|
||||||
|
@ -1994,6 +2005,8 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) {
|
private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) {
|
||||||
|
if (num == 1)
|
||||||
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
|
||||||
|
|
||||||
|
|
|
@ -1683,4 +1683,12 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
|
val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
|
||||||
assertEquals(e, z);
|
assertEquals(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLinSpaceEdge_1() {
|
||||||
|
val x = Nd4j.linspace(1,10,1, DataType.FLOAT);
|
||||||
|
val e = Nd4j.scalar(1.0f);
|
||||||
|
|
||||||
|
assertEquals(e, x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue