Lin_space operation improve (#373)

* libnd4j update linspace op

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j #8513 update lin_space op, tests added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - minor linspace tweaks (num_elements now iArg)
- java linspace updates
- couple of additional tests for linspace

Signed-off-by: raver119 <raver119@gmail.com>

* roll back timeout change

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Oleh 2020-04-16 14:53:56 +03:00 committed by GitHub
parent 12ba1fa406
commit 3d15706ffa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 108 additions and 17 deletions

View File

@ -26,24 +26,38 @@
namespace sd {
namespace ops {
CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
auto start = INPUT_VARIABLE(0);
auto finish = INPUT_VARIABLE(1);
auto numOfElements = INPUT_VARIABLE(2);
CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) {
if (numOfElements->e<Nd4jLong>(0) == 1) {
auto output = OUTPUT_VARIABLE(0);
const int nInputs = block.width();
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT());
auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e<double>(0) : static_cast<double>(T_ARG(0));
auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e<double>(0) : static_cast<double>(T_ARG(1));
auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
if (numOfElements == 1) {
output->assign(start);
return Status::OK();
}
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
output->linspace(start, (finish - start) / ( numOfElements - 1.0 ));
return Status::OK();
}
DECLARE_SHAPE_FN(lin_space) {
auto dataType = ArrayOptions::dataType(inputShape->at(0));
Nd4jLong steps = INPUT_VARIABLE(2)->e<Nd4jLong>(0);
const int nInputs = block.width();
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() );
auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast<DataType>(D_ARG(0)) : DataType::FLOAT32) ;
Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType));
}

View File

@ -1433,16 +1433,20 @@ namespace sd {
/**
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
*
* input params:
* optional input params:
* 0 - startVal - NDArray scalar (float point)
* 1 - finishVal - NDArray scalar (float point)
* 2 - numOfElements - NDArray scalar (integer)
*
* Optional:
* T args
* 0 - startVal
* 1 - finishVal]
* 2 - numOfElements
* output:
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
*/
#if NOT_EXCLUDED(OP_lin_space)
DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0);
DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0);
#endif
/**

View File

@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
ASSERT_TRUE(expect.equalsTo(res));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, LinSpace_Test2) {
NDArray expect = NDArrayFactory::create<float>({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5,
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
sd::ops::lin_space op;
auto result = op.evaluate({}, {1, 12}, {23});
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
auto res = result.at(0);
ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 );
ASSERT_TRUE(expect.equalsTo(res));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, LinSpace_Test3) {
NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE );
sd::ops::lin_space op;
auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE });
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
auto res = result.at(0);
ASSERT_EQ( res->dataType(), expect.dataType());
ASSERT_TRUE(expect.equalsTo(res));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {

View File

@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
ASSERT_EQ(Status::OK(), status);
}
TEST_F(JavaInteropTests, test_linspace_shape_1) {
if (!Environment::getInstance()->isCPU())
return;
sd::ops::lin_space op;
double tArgs[2] = {1.0, 10.0};
Nd4jLong iArgs = 10L;
int dArg = (int) sd::DataType::FLOAT32;
auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1);
ASSERT_EQ(1, result->size());
delete result;
}
/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -42,6 +42,9 @@ import java.util.Map;
public class Linspace extends DynamicCustomOp {
private DataType dataType;
private double start;
private double stop;
private long elements;
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
@ -54,7 +57,7 @@ public class Linspace extends DynamicCustomOp {
}
public Linspace(DataType dataType, double start, double stop, long number) {
this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number));
this(start, stop, number, dataType);
}
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
@ -67,6 +70,19 @@ public class Linspace extends DynamicCustomOp {
addDArgument(dataType);
}
public Linspace(double start, double stop, long number, @NonNull DataType dataType) {
super(new INDArray[]{}, null);
this.dataType = dataType;
addDArgument(dataType);
this.start = start;
this.stop = stop;
this.elements = number;
addTArgument(this.start, this.stop);
addIArgument(elements);
}
public Linspace(){ }
@Override

View File

@ -1947,7 +1947,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val result = new ArrayList<LongShapeDescriptor>();
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
if(log.isTraceEnabled()){
log.trace("Could not calculate output shape for op {}: number of input args was 0",
op.getClass().getName());

View File

@ -1754,7 +1754,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val result = new ArrayList<LongShapeDescriptor>();
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
if(log.isTraceEnabled()){
log.trace("Could not calculate output shape for op {}: number of input args was 0",
op.getClass().getName());

View File

@ -20475,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
*
* input params:
* optional input params:
* 0 - startVal - NDArray scalar (float point)
* 1 - finishVal - NDArray scalar (float point)
* 2 - numOfElements - NDArray scalar (integer)
*
* Optional:
* T args
* 0 - startVal
* 1 - finishVal]
* 2 - numOfElements
* output:
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
*/

View File

@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.transforms.Cholesky;
@ -1803,6 +1804,16 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(ret[0], in);
}
@Test
public void testLinspaceSignature_1() throws Exception {
val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0];
val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0];
assertEquals(array1.dataType(), array2.dataType());
assertEquals(array1, array2);
}
@Test
public void testLogdet() {
INDArray x = Nd4j.createFromArray(new double[]{