Few minor fixes (#381)
* - 1D indexing fix - couple of new tests for 1D indexing Signed-off-by: raver119 <raver119@gmail.com> * percentile fix + test Signed-off-by: raver119 <raver119@gmail.com> * wrong signature used in test Signed-off-by: raver119 <raver119@gmail.com>master
parent
4247718f61
commit
12ba1fa406
|
@ -2045,8 +2045,18 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
|
throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
|
||||||
}
|
}
|
||||||
|
|
||||||
if(indices.rows() == rank()) {
|
if (rank() == 1) {
|
||||||
INDArray ret = Nd4j.create(indices.dataType(), indices.columns());
|
Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well");
|
||||||
|
val ret = Nd4j.createUninitialized(this.dataType(), indices.length());
|
||||||
|
for (int e = 0; e < indices.length(); e++) {
|
||||||
|
val idx = indices.getLong(e);
|
||||||
|
val value = getDouble(idx);
|
||||||
|
ret.putScalar(e, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
} else if(indices.rows() == rank()) {
|
||||||
|
INDArray ret = Nd4j.create(this.dataType(), indices.columns());
|
||||||
|
|
||||||
for(int i = 0; i < indices.columns(); i++) {
|
for(int i = 0; i < indices.columns(); i++) {
|
||||||
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
|
int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
|
||||||
|
@ -5391,6 +5401,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return sorted.getDouble(sorted.length() - 1);
|
return sorted.getDouble(sorted.length() - 1);
|
||||||
|
|
||||||
double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1);
|
double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1);
|
||||||
|
if (pos < 1)
|
||||||
|
return sorted.getDouble(0);
|
||||||
|
else if (pos >= sorted.length())
|
||||||
|
return sorted.getDouble(sorted.length() - 1);
|
||||||
|
|
||||||
double fposition = FastMath.floor(pos);
|
double fposition = FastMath.floor(pos);
|
||||||
int position = (int)fposition;
|
int position = (int)fposition;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE
|
// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE
|
||||||
|
|
||||||
package org.nd4j.nativeblas;
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE
|
// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE
|
||||||
|
|
||||||
package org.nd4j.nativeblas;
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
@ -16040,6 +16040,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
// #if NOT_EXCLUDED(OP_lstmLayerCell)
|
||||||
|
@Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public lstmLayerCell(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public lstmLayerCell position(long position) {
|
||||||
|
return (lstmLayerCell)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public lstmLayerCell() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
// #if NOT_EXCLUDED(OP_lstmLayerCell)
|
||||||
|
@Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public lstmLayerCellBp(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public lstmLayerCellBp position(long position) {
|
||||||
|
return (lstmLayerCellBp)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public lstmLayerCellBp() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
|
@ -16169,6 +16204,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// #if NOT_EXCLUDED(OP_lstmLayer)
|
||||||
|
@Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public lstmLayer_bp(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public lstmLayer_bp position(long position) {
|
||||||
|
return (lstmLayer_bp)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public lstmLayer_bp() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
|
@ -16336,6 +16390,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
// #if NOT_EXCLUDED(OP_gru)
|
||||||
|
@Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public gru_bp(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public gru_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public gru_bp position(long position) {
|
||||||
|
return (gru_bp)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public gru_bp() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
* Implementation of operation "static RNN time sequences" with peep hole connections:
|
* Implementation of operation "static RNN time sequences" with peep hole connections:
|
||||||
|
|
|
@ -5607,6 +5607,13 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, array.percentileNumber(75));
|
assertEquals(exp, array.percentileNumber(75));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPercentile5() {
|
||||||
|
val array = Nd4j.createFromArray(new int[]{1, 1982});
|
||||||
|
val perc = array.percentileNumber(75);
|
||||||
|
assertEquals(1982.f, perc.floatValue(), 1e-5f);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTadPercentile1() {
|
public void testTadPercentile1() {
|
||||||
INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.shape.indexing;
|
package org.nd4j.linalg.shape.indexing;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.ErrorCollector;
|
import org.junit.rules.ErrorCollector;
|
||||||
|
@ -190,6 +191,24 @@ public class IndexingTestsC extends BaseNd4jTest {
|
||||||
assertTrue(last10b.getDouble(i) == 20 + i);
|
assertTrue(last10b.getDouble(i) == 20 + i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test1dSubarray_1() {
|
||||||
|
val data = Nd4j.linspace(DataType.FLOAT,0, 10, 1);
|
||||||
|
val exp = Nd4j.createFromArray(new float[]{3.f, 4.f});
|
||||||
|
val dataAtIndex = data.get(NDArrayIndex.interval(3, 5));
|
||||||
|
|
||||||
|
assertEquals(exp, dataAtIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test1dSubarray_2() {
|
||||||
|
val data = Nd4j.linspace(DataType.FLOAT,1, 10, 1);
|
||||||
|
val exp = Nd4j.createFromArray(new float[]{4.f, 6.f});
|
||||||
|
val dataAtIndex = data.get(Nd4j.createFromArray(new int[]{3, 5}));
|
||||||
|
|
||||||
|
assertEquals(exp, dataAtIndex);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGet() {
|
public void testGet() {
|
||||||
// System.out.println("Testing sub-array put and get with a 3D array ...");
|
// System.out.println("Testing sub-array put and get with a 3D array ...");
|
||||||
|
|
Loading…
Reference in New Issue