Few minor fixes (#381)

* - 1D indexing fix
- couple of new tests for 1D indexing

Signed-off-by: raver119 <raver119@gmail.com>

* percentile fix + test

Signed-off-by: raver119 <raver119@gmail.com>

* wrong signature used in test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-16 13:25:13 +03:00 committed by GitHub
parent 4247718f61
commit 12ba1fa406
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 116 additions and 4 deletions

View File

@ -2045,8 +2045,18 @@ public abstract class BaseNDArray implements INDArray, Iterable {
throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
} }
if(indices.rows() == rank()) { if (rank() == 1) {
INDArray ret = Nd4j.create(indices.dataType(), indices.columns()); Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well");
val ret = Nd4j.createUninitialized(this.dataType(), indices.length());
for (int e = 0; e < indices.length(); e++) {
val idx = indices.getLong(e);
val value = getDouble(idx);
ret.putScalar(e, value);
}
return ret;
} else if(indices.rows() == rank()) {
INDArray ret = Nd4j.create(this.dataType(), indices.columns());
for(int i = 0; i < indices.columns(); i++) { for(int i = 0; i < indices.columns(); i++) {
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
@ -5391,6 +5401,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return sorted.getDouble(sorted.length() - 1); return sorted.getDouble(sorted.length() - 1);
double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1);
if (pos < 1)
return sorted.getDouble(0);
else if (pos >= sorted.length())
return sorted.getDouble(sorted.length() - 1);
double fposition = FastMath.floor(pos); double fposition = FastMath.floor(pos);
int position = (int)fposition; int position = (int)fposition;

View File

@ -1,4 +1,4 @@
// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE // Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas; package org.nd4j.nativeblas;

View File

@ -1,4 +1,4 @@
// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE // Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas; package org.nd4j.nativeblas;
@ -16040,6 +16040,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
// #if NOT_EXCLUDED(OP_lstmLayerCell)
@Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public lstmLayerCell(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public lstmLayerCell position(long position) {
return (lstmLayerCell)super.position(position);
}
public lstmLayerCell() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
// #if NOT_EXCLUDED(OP_lstmLayerCell)
@Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public lstmLayerCellBp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public lstmLayerCellBp position(long position) {
return (lstmLayerCellBp)super.position(position);
}
public lstmLayerCellBp() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
@ -16169,6 +16204,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
//////////////////////////////////////////////////////////////////////////
// #if NOT_EXCLUDED(OP_lstmLayer)
@Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public lstmLayer_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public lstmLayer_bp position(long position) {
return (lstmLayer_bp)super.position(position);
}
public lstmLayer_bp() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
@ -16336,6 +16390,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
// #if NOT_EXCLUDED(OP_gru)
@Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public gru_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public gru_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public gru_bp position(long position) {
return (gru_bp)super.position(position);
}
public gru_bp() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
* Implementation of operation "static RNN time sequences" with peep hole connections: * Implementation of operation "static RNN time sequences" with peep hole connections:

View File

@ -5607,6 +5607,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, array.percentileNumber(75)); assertEquals(exp, array.percentileNumber(75));
} }
@Test
public void testPercentile5() {
val array = Nd4j.createFromArray(new int[]{1, 1982});
val perc = array.percentileNumber(75);
assertEquals(1982.f, perc.floatValue(), 1e-5f);
}
@Test @Test
public void testTadPercentile1() { public void testTadPercentile1() {
INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.shape.indexing; package org.nd4j.linalg.shape.indexing;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ErrorCollector; import org.junit.rules.ErrorCollector;
@ -190,6 +191,24 @@ public class IndexingTestsC extends BaseNd4jTest {
assertTrue(last10b.getDouble(i) == 20 + i); assertTrue(last10b.getDouble(i) == 20 + i);
} }
@Test
public void test1dSubarray_1() {
val data = Nd4j.linspace(DataType.FLOAT,0, 10, 1);
val exp = Nd4j.createFromArray(new float[]{3.f, 4.f});
val dataAtIndex = data.get(NDArrayIndex.interval(3, 5));
assertEquals(exp, dataAtIndex);
}
@Test
public void test1dSubarray_2() {
val data = Nd4j.linspace(DataType.FLOAT,1, 10, 1);
val exp = Nd4j.createFromArray(new float[]{4.f, 6.f});
val dataAtIndex = data.get(Nd4j.createFromArray(new int[]{3, 5}));
assertEquals(exp, dataAtIndex);
}
@Test @Test
public void testGet() { public void testGet() {
// System.out.println("Testing sub-array put and get with a 3D array ..."); // System.out.println("Testing sub-array put and get with a 3D array ...");