diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 07a2bf9b8..46daa869b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2045,8 +2045,18 @@ public abstract class BaseNDArray implements INDArray, Iterable { throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } - if(indices.rows() == rank()) { - INDArray ret = Nd4j.create(indices.dataType(), indices.columns()); + if (rank() == 1) { + Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well"); + val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); + for (int e = 0; e < indices.length(); e++) { + val idx = indices.getLong(e); + val value = getDouble(idx); + ret.putScalar(e, value); + } + + return ret; + } else if(indices.rows() == rank()) { + INDArray ret = Nd4j.create(this.dataType(), indices.columns()); for(int i = 0; i < indices.columns(); i++) { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); @@ -5391,6 +5401,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { return sorted.getDouble(sorted.length() - 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); + if (pos < 1) + return sorted.getDouble(0); + else if (pos >= sorted.length()) + return sorted.getDouble(sorted.length() - 1); double fposition = FastMath.floor(pos); int position = (int)fposition; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 8f30cdd82..6753b5ea1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e96325460..adddf5e42 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -16040,6 +16040,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCell position(long position) { + return (lstmLayerCell)super.position(position); + } + + public lstmLayerCell() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCellBp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCellBp position(long position) { + return (lstmLayerCellBp)super.position(position); + } + + public lstmLayerCellBp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16169,6 +16204,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + ////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) + @Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer_bp position(long position) { + return (lstmLayer_bp)super.position(position); + } + + public lstmLayer_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16336,6 +16390,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_gru) + @Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru_bp position(long position) { + return (gru_bp)super.position(position); + } + + public gru_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operation "static RNN time sequences" with peep hole connections: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 162e123b8..886384f4a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -5607,6 +5607,13 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, array.percentileNumber(75)); } + @Test + public void testPercentile5() { + val array = Nd4j.createFromArray(new int[]{1, 1982}); + val perc = array.percentileNumber(75); + assertEquals(1982.f, perc.floatValue(), 1e-5f); + } + @Test public void testTadPercentile1() { INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index 33a64c291..24c6b30d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.shape.indexing; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; @@ -190,6 +191,24 @@ public class IndexingTestsC extends BaseNd4jTest { assertTrue(last10b.getDouble(i) == 20 + i); } + @Test + public void test1dSubarray_1() { + val data = Nd4j.linspace(DataType.FLOAT,0, 10, 1); + val exp = Nd4j.createFromArray(new float[]{3.f, 4.f}); + val dataAtIndex = data.get(NDArrayIndex.interval(3, 5)); + + assertEquals(exp, dataAtIndex); + } + + @Test + public void test1dSubarray_2() { + val data = Nd4j.linspace(DataType.FLOAT,1, 10, 1); + val exp = Nd4j.createFromArray(new float[]{4.f, 6.f}); + val dataAtIndex = data.get(Nd4j.createFromArray(new int[]{3, 5})); + + assertEquals(exp, dataAtIndex); + } + @Test public void testGet() { // System.out.println("Testing sub-array put and get with a 3D array ...");