Lin_space operation improve (#373)
* libnd4j update linspace op Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j #8513 update lin_space op, tests added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - minor linspace tweaks (num_elements now iArg) - java linspace updates - couple of additional tests for linspace Signed-off-by: raver119 <raver119@gmail.com> * roll back timeout change Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
12ba1fa406
commit
3d15706ffa
|
@ -26,24 +26,38 @@
|
|||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
auto start = INPUT_VARIABLE(0);
|
||||
auto finish = INPUT_VARIABLE(1);
|
||||
auto numOfElements = INPUT_VARIABLE(2);
|
||||
CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) {
|
||||
|
||||
if (numOfElements->e<Nd4jLong>(0) == 1) {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int nInputs = block.width();
|
||||
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
|
||||
|
||||
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT());
|
||||
|
||||
auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e<double>(0) : static_cast<double>(T_ARG(0));
|
||||
auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e<double>(0) : static_cast<double>(T_ARG(1));
|
||||
auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
|
||||
|
||||
if (numOfElements == 1) {
|
||||
output->assign(start);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
|
||||
output->linspace(start, (finish - start) / ( numOfElements - 1.0 ));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(lin_space) {
|
||||
auto dataType = ArrayOptions::dataType(inputShape->at(0));
|
||||
Nd4jLong steps = INPUT_VARIABLE(2)->e<Nd4jLong>(0);
|
||||
|
||||
const int nInputs = block.width();
|
||||
bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
|
||||
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() );
|
||||
|
||||
|
||||
auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast<DataType>(D_ARG(0)) : DataType::FLOAT32) ;
|
||||
Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType));
|
||||
}
|
||||
|
||||
|
|
|
@ -1433,16 +1433,20 @@ namespace sd {
|
|||
/**
|
||||
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
||||
*
|
||||
* input params:
|
||||
* optional input params:
|
||||
* 0 - startVal - NDArray scalar (float point)
|
||||
* 1 - finishVal - NDArray scalar (float point)
|
||||
* 2 - numOfElements - NDArray scalar (integer)
|
||||
*
|
||||
* Optional:
|
||||
* T args
|
||||
* 0 - startVal
|
||||
* 1 - finishVal]
|
||||
* 2 - numOfElements
|
||||
* output:
|
||||
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_lin_space)
|
||||
DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0);
|
||||
DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
|
|||
|
||||
ASSERT_TRUE(expect.equalsTo(res));
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, LinSpace_Test2) {
|
||||
|
||||
NDArray expect = NDArrayFactory::create<float>({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5,
|
||||
8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
|
||||
|
||||
sd::ops::lin_space op;
|
||||
auto result = op.evaluate({}, {1, 12}, {23});
|
||||
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
|
||||
auto res = result.at(0);
|
||||
ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 );
|
||||
ASSERT_TRUE(expect.equalsTo(res));
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, LinSpace_Test3) {
|
||||
|
||||
NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE );
|
||||
|
||||
sd::ops::lin_space op;
|
||||
auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE });
|
||||
ASSERT_EQ(result.status(), ND4J_STATUS_OK);
|
||||
auto res = result.at(0);
|
||||
|
||||
ASSERT_EQ( res->dataType(), expect.dataType());
|
||||
ASSERT_TRUE(expect.equalsTo(res));
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||
|
|
|
@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) {
|
|||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, test_linspace_shape_1) {
|
||||
if (!Environment::getInstance()->isCPU())
|
||||
return;
|
||||
|
||||
sd::ops::lin_space op;
|
||||
double tArgs[2] = {1.0, 10.0};
|
||||
Nd4jLong iArgs = 10L;
|
||||
int dArg = (int) sd::DataType::FLOAT32;
|
||||
auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1);
|
||||
|
||||
ASSERT_EQ(1, result->size());
|
||||
delete result;
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||
auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||
|
|
|
@ -42,6 +42,9 @@ import java.util.Map;
|
|||
public class Linspace extends DynamicCustomOp {
|
||||
|
||||
private DataType dataType;
|
||||
private double start;
|
||||
private double stop;
|
||||
private long elements;
|
||||
|
||||
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
||||
this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType);
|
||||
|
@ -54,7 +57,7 @@ public class Linspace extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
public Linspace(DataType dataType, double start, double stop, long number) {
|
||||
this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number));
|
||||
this(start, stop, number, dataType);
|
||||
}
|
||||
|
||||
public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) {
|
||||
|
@ -67,6 +70,19 @@ public class Linspace extends DynamicCustomOp {
|
|||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public Linspace(double start, double stop, long number, @NonNull DataType dataType) {
|
||||
super(new INDArray[]{}, null);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
|
||||
this.start = start;
|
||||
this.stop = stop;
|
||||
this.elements = number;
|
||||
|
||||
addTArgument(this.start, this.stop);
|
||||
addIArgument(elements);
|
||||
}
|
||||
|
||||
public Linspace(){ }
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1947,7 +1947,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val result = new ArrayList<LongShapeDescriptor>();
|
||||
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
||||
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
|
||||
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
|
||||
if(log.isTraceEnabled()){
|
||||
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
||||
op.getClass().getName());
|
||||
|
|
|
@ -1754,7 +1754,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val result = new ArrayList<LongShapeDescriptor>();
|
||||
int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
|
||||
if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
|
||||
if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
|
||||
if(log.isTraceEnabled()){
|
||||
log.trace("Could not calculate output shape for op {}: number of input args was 0",
|
||||
op.getClass().getName());
|
||||
|
|
|
@ -20475,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
/**
|
||||
* lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
|
||||
*
|
||||
* input params:
|
||||
* optional input params:
|
||||
* 0 - startVal - NDArray scalar (float point)
|
||||
* 1 - finishVal - NDArray scalar (float point)
|
||||
* 2 - numOfElements - NDArray scalar (integer)
|
||||
*
|
||||
* Optional:
|
||||
* T args
|
||||
* 0 - startVal
|
||||
* 1 - finishVal]
|
||||
* 2 - numOfElements
|
||||
* output:
|
||||
* 0 - 1D NDArray with the same type as input and length as given with numOfElements param.
|
||||
*/
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
|||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.Cholesky;
|
||||
|
@ -1803,6 +1804,16 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertEquals(ret[0], in);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testLinspaceSignature_1() throws Exception {
|
||||
val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0];
|
||||
val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0];
|
||||
|
||||
assertEquals(array1.dataType(), array2.dataType());
|
||||
assertEquals(array1, array2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLogdet() {
|
||||
INDArray x = Nd4j.createFromArray(new double[]{
|
||||
|
|
Loading…
Reference in New Issue