Lin_space operation improve (#373)
* libnd4j update linspace op Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j #8513 update lin_space op, tests added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - minor linspace tweaks (num_elements now iArg) - java linspace updates - couple of additional tests for linspace Signed-off-by: raver119 <raver119@gmail.com> * roll back timeout change Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
12ba1fa406
commit
3d15706ffa
|
@ -26,24 +26,38 @@
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
auto start = INPUT_VARIABLE(0);
|
|
||||||
auto finish = INPUT_VARIABLE(1);
|
|
||||||
auto numOfElements = INPUT_VARIABLE(2);
|
|
||||||
|
|
||||||
if (numOfElements->e<Nd4jLong>(0) == 1) {
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const int nInputs = block.width();
|
||||||
|
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT());
|
||||||
|
|
||||||
|
auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e<double>(0) : static_cast<double>(T_ARG(0));
|
||||||
|
auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e<double>(0) : static_cast<double>(T_ARG(1));
|
||||||
|
auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
|
||||||
|
|
||||||
|
if (numOfElements == 1) {
|
||||||
output->assign(start);
|
output->assign(start);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
|
output->linspace(start, (finish - start) / ( numOfElements - 1.0 ));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(lin_space) {
|
DECLARE_SHAPE_FN(lin_space) {
|
||||||
auto dataType = ArrayOptions::dataType(inputShape->at(0));
|
|
||||||
Nd4jLong steps = INPUT_VARIABLE(2)->e<Nd4jLong>(0);
|
const int nInputs = block.width();
|
||||||
|
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
|
||||||
|
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() );
|
||||||
|
|
||||||
|
|
||||||
|
auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast<DataType>(D_ARG(0)) : DataType::FLOAT32) ;
|
||||||
|
Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1433,16 +1433,20 @@ namespace sd {
|
||||||
/**
|
/**
|
||||||
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
||||||
*
|
*
|
||||||
* input params:
|
* optional input params:
|
||||||
* 0 - startVal - NDArray scalar (float point)
|
* 0 - startVal - NDArray scalar (float point)
|
||||||
* 1 - finishVal - NDArray scalar (float point)
|
* 1 - finishVal - NDArray scalar (float point)
|
||||||
* 2 - numOfElements - NDArray scalar (integer)
|
* 2 - numOfElements - NDArray scalar (integer)
|
||||||
*
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - startVal
|
||||||
|
* 1 - finishVal]
|
||||||
|
* 2 - numOfElements
|
||||||
* output:
|
* output:
|
||||||
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_lin_space)
|
#if NOT_EXCLUDED(OP_lin_space)
|
||||||
DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
|
||||||
|
|
||||||
ASSERT_TRUE(expect.equalsTo(res));
|
ASSERT_TRUE(expect.equalsTo(res));
|
||||||
|
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, LinSpace_Test2) {
|
||||||
|
|
||||||
|
NDArray expect = NDArrayFactory::create<float>({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5,
|
||||||
|
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
|
||||||
|
|
||||||
|
sd::ops::lin_space op;
|
||||||
|
auto result = op.evaluate({}, {1, 12}, {23});
|
||||||
|
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
|
||||||
|
auto res = result.at(0);
|
||||||
|
ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 );
|
||||||
|
ASSERT_TRUE(expect.equalsTo(res));
|
||||||
|
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, LinSpace_Test3) {
|
||||||
|
|
||||||
|
NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE );
|
||||||
|
|
||||||
|
sd::ops::lin_space op;
|
||||||
|
auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE });
|
||||||
|
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
|
||||||
|
auto res = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ( res->dataType(), expect.dataType());
|
||||||
|
ASSERT_TRUE(expect.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
|
|
|
@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(JavaInteropTests, test_linspace_shape_1) {
|
||||||
|
if (!Environment::getInstance()->isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
|
sd::ops::lin_space op;
|
||||||
|
double tArgs[2] = {1.0, 10.0};
|
||||||
|
Nd4jLong iArgs = 10L;
|
||||||
|
int dArg = (int) sd::DataType::FLOAT32;
|
||||||
|
auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result->size());
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||||
|
|
|
@ -42,6 +42,9 @@ import java.util.Map;
|
||||||
public class Linspace extends DynamicCustomOp {
|
public class Linspace extends DynamicCustomOp {
|
||||||
|
|
||||||
private DataType dataType;
|
private DataType dataType;
|
||||||
|
private double start;
|
||||||
|
private double stop;
|
||||||
|
private long elements;
|
||||||
|
|
||||||
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
||||||
this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
|
this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
|
||||||
|
@ -54,7 +57,7 @@ public class Linspace extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Linspace(DataType dataType, double start, double stop, long number) {
|
public Linspace(DataType dataType, double start, double stop, long number) {
|
||||||
this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number));
|
this(start, stop, number, dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
|
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
|
||||||
|
@ -67,6 +70,19 @@ public class Linspace extends DynamicCustomOp {
|
||||||
addDArgument(dataType);
|
addDArgument(dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Linspace(double start, double stop, long number, @NonNull DataType dataType) {
|
||||||
|
super(new INDArray[]{}, null);
|
||||||
|
this.dataType = dataType;
|
||||||
|
addDArgument(dataType);
|
||||||
|
|
||||||
|
this.start = start;
|
||||||
|
this.stop = stop;
|
||||||
|
this.elements = number;
|
||||||
|
|
||||||
|
addTArgument(this.start, this.stop);
|
||||||
|
addIArgument(elements);
|
||||||
|
}
|
||||||
|
|
||||||
public Linspace(){ }
|
public Linspace(){ }
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1947,7 +1947,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val result = new ArrayList<LongShapeDescriptor>();
|
val result = new ArrayList<LongShapeDescriptor>();
|
||||||
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
||||||
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
|
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
|
||||||
if(log.isTraceEnabled()){
|
if(log.isTraceEnabled()){
|
||||||
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
||||||
op.getClass().getName());
|
op.getClass().getName());
|
||||||
|
|
|
@ -1754,7 +1754,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val result = new ArrayList<LongShapeDescriptor>();
|
val result = new ArrayList<LongShapeDescriptor>();
|
||||||
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
||||||
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
|
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
|
||||||
if(log.isTraceEnabled()){
|
if(log.isTraceEnabled()){
|
||||||
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
||||||
op.getClass().getName());
|
op.getClass().getName());
|
||||||
|
|
|
@ -20475,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
/**
|
/**
|
||||||
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
||||||
*
|
*
|
||||||
* input params:
|
* optional input params:
|
||||||
* 0 - startVal - NDArray scalar (float point)
|
* 0 - startVal - NDArray scalar (float point)
|
||||||
* 1 - finishVal - NDArray scalar (float point)
|
* 1 - finishVal - NDArray scalar (float point)
|
||||||
* 2 - numOfElements - NDArray scalar (integer)
|
* 2 - numOfElements - NDArray scalar (integer)
|
||||||
*
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - startVal
|
||||||
|
* 1 - finishVal]
|
||||||
|
* 2 - numOfElements
|
||||||
* output:
|
* output:
|
||||||
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.Cholesky;
|
import org.nd4j.linalg.api.ops.impl.transforms.Cholesky;
|
||||||
|
@ -1803,6 +1804,16 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(ret[0], in);
|
assertEquals(ret[0], in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLinspaceSignature_1() throws Exception {
|
||||||
|
val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0];
|
||||||
|
val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0];
|
||||||
|
|
||||||
|
assertEquals(array1.dataType(), array2.dataType());
|
||||||
|
assertEquals(array1, array2);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLogdet() {
|
public void testLogdet() {
|
||||||
INDArray x = Nd4j.createFromArray(new double[]{
|
INDArray x = Nd4j.createFromArray(new double[]{
|
||||||
|
|
Loading…
Reference in New Issue