Merge pull request #7 from KonduitAI/asto_nd4s_10172019

KDTree optimization
master
Alexander Stoyakin 2019-10-23 12:11:25 +03:00 committed by GitHub
parent 35e6ffede4
commit f31661e13b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 831 additions and 282 deletions

View File

@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree;
import lombok.val; import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
@ -28,79 +31,103 @@ import java.util.List;
*/ */
public class HyperRect implements Serializable { public class HyperRect implements Serializable {
private List<Interval> points; //private List<Interval> points;
private float[] lowerEnds;
private float[] higherEnds;
private INDArray lowerEndsIND;
private INDArray higherEndsIND;
public HyperRect(List<Interval> points) { public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
//this.points = points; this.lowerEnds = new float[lowerEndsIn.length];
this.points = new ArrayList<>(points.size()); this.higherEnds = new float[lowerEndsIn.length];
for (int i = 0; i < points.size(); ++i) { System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher); System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
this.points.add(newInterval); lowerEndsIND = Nd4j.createFromArray(lowerEnds);
} higherEndsIND = Nd4j.createFromArray(higherEnds);
}
public HyperRect(float[] point) {
this(point, point);
}
public HyperRect(Pair<float[], float[]> ends) {
this(ends.getFirst(), ends.getSecond());
} }
public void enlargeTo(INDArray point) { public void enlargeTo(INDArray point) {
for (int i = 0; i < points.size(); i++) float[] pointAsArray = point.toFloatVector();
points.get(i).enlarge(point.getDouble(i)); for (int i = 0; i < lowerEnds.length; i++) {
float p = pointAsArray[i];
if (lowerEnds[i] > p)
lowerEnds[i] = p;
else if (higherEnds[i] < p)
higherEnds[i] = p;
}
} }
public static Pair<float[],float[]> point(INDArray vector) {
public static List<Interval> point(INDArray vector) { Pair<float[],float[]> ret = new Pair<>();
List<Interval> ret = new ArrayList<>(); float[] curr = new float[(int)vector.length()];
for (int i = 0; i < vector.length(); i++) { for (int i = 0; i < vector.length(); i++) {
double curr = vector.getDouble(i); curr[i] = vector.getFloat(i);
ret.add(new Interval(curr, curr));
} }
ret.setFirst(curr);
ret.setSecond(curr);
return ret; return ret;
} }
public List<Boolean> contains(INDArray hPoint) { /*public List<Boolean> contains(INDArray hPoint) {
List<Boolean> ret = new ArrayList<>(); List<Boolean> ret = new ArrayList<>();
for (int i = 0; i < hPoint.length(); i++)
ret.add(points.get(i).contains(hPoint.getDouble(i)));
return ret;
}
public double minDistance(INDArray hPoint) {
double ret = 0.0;
for (int i = 0; i < hPoint.length(); i++) { for (int i = 0; i < hPoint.length(); i++) {
double p = hPoint.getDouble(i); ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
Interval interval = points.get(i); higherEnds[i] >= hPoint.getDouble(i));
if (!interval.contains(p)) {
if (p < interval.lower)
ret += Math.pow((p - interval.lower), 2);
else
ret += Math.pow((p - interval.higher), 2);
}
} }
ret = Math.pow(ret, 0.5);
return ret; return ret;
}*/
public double minDistance(INDArray hPoint, INDArray output) {
Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
return output.getFloat(0);
/*double ret = 0.0;
double[] pointAsArray = hPoint.toDoubleVector();
for (int i = 0; i < pointAsArray.length; i++) {
double p = pointAsArray[i];
if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
if (p < lowerEnds[i])
ret += Math.pow((p - lowerEnds[i]), 2);
else
ret += Math.pow((p - higherEnds[i]), 2);
}
}
ret = Math.pow(ret, 0.5);
return ret;*/
} }
public HyperRect getUpper(INDArray hPoint, int desc) { public HyperRect getUpper(INDArray hPoint, int desc) {
Interval interval = points.get(desc); //Interval interval = points.get(desc);
double d = hPoint.getDouble(desc); float higher = higherEnds[desc];
if (interval.higher < d) float d = hPoint.getFloat(desc);
if (higher < d)
return null; return null;
HyperRect ret = new HyperRect(new ArrayList<>(points)); HyperRect ret = new HyperRect(lowerEnds,higherEnds);
Interval i2 = ret.points.get(desc); if (ret.lowerEnds[desc] < d)
if (i2.lower < d) ret.lowerEnds[desc] = d;
i2.lower = d;
return ret; return ret;
} }
public HyperRect getLower(INDArray hPoint, int desc) { public HyperRect getLower(INDArray hPoint, int desc) {
Interval interval = points.get(desc); //Interval interval = points.get(desc);
double d = hPoint.getDouble(desc); float lower = lowerEnds[desc];
if (interval.lower > d) float d = hPoint.getFloat(desc);
if (lower > d)
return null; return null;
HyperRect ret = new HyperRect(new ArrayList<>(points)); HyperRect ret = new HyperRect(lowerEnds,higherEnds);
Interval i2 = ret.points.get(desc); //Interval i2 = ret.points.get(desc);
if (i2.higher > d) if (ret.higherEnds[desc] > d)
i2.higher = d; ret.higherEnds[desc] = d;
return ret; return ret;
} }
@ -108,33 +135,10 @@ public class HyperRect implements Serializable {
public String toString() { public String toString() {
String retVal = ""; String retVal = "";
retVal += "["; retVal += "[";
for (val point : points) { for (int i = 0; i < lowerEnds.length; ++i) {
retVal += "(" + point.lower + " - " + point.higher + ") "; retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
} }
retVal += "]"; retVal += "]";
return retVal; return retVal;
} }
public static class Interval {
private double lower, higher;
public Interval(double lower, double higher) {
this.lower = lower;
this.higher = higher;
}
public boolean contains(double point) {
return lower <= point || point <= higher;
}
public void enlarge(double p) {
if (lower > p)
lower = p;
else if (higher < p)
higher = p;
}
}
} }

View File

@ -56,7 +56,7 @@ public class KDTree implements Serializable {
if (root == null) { if (root == null) {
root = new KDNode(point); root = new KDNode(point);
rect = new HyperRect(HyperRect.point(point)); rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
} else { } else {
int disc = 0; int disc = 0;
KDNode node = root; KDNode node = root;
@ -125,15 +125,21 @@ public class KDTree implements Serializable {
return node.getPoint(); return node.getPoint();
} }
// Share this data for recursive calls of "knn"
private float currentDistance;
private INDArray currentPoint;
private INDArray minDistance = Nd4j.scalar(0.f);
public List<Pair<Double, INDArray>> knn(INDArray point, double distance) { public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
List<Pair<Double, INDArray>> best = new ArrayList<>(); List<Pair<Float, INDArray>> best = new ArrayList<>();
knn(root, point, rect, distance, best, 0); currentDistance = distance;
Collections.sort(best, new Comparator<Pair<Double, INDArray>>() { currentPoint = point;
knn(root, rect, best, 0);
Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
@Override @Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) { public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
return Double.compare(o1.getKey(), o2.getKey()); return Float.compare(o1.getKey(), o2.getKey());
} }
}); });
@ -141,22 +147,21 @@ public class KDTree implements Serializable {
} }
private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best, private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
int _disc) { if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
if (node == null || rect == null || rect.minDistance(point) > dist)
return; return;
int _discNext = (_disc + 1) % dims; int _discNext = (_disc + 1) % dims;
double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult() float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
.doubleValue(); .floatValue();
if (distance <= dist) { if (distance <= currentDistance) {
best.add(Pair.of(distance, node.getPoint())); best.add(Pair.of(distance, node.getPoint()));
} }
HyperRect lower = rect.getLower(node.point, _disc); HyperRect lower = rect.getLower(node.point, _disc);
HyperRect upper = rect.getUpper(node.point, _disc); HyperRect upper = rect.getUpper(node.point, _disc);
knn(node.getLeft(), point, lower, dist, best, _discNext); knn(node.getLeft(), lower, best, _discNext);
knn(node.getRight(), point, upper, dist, best, _discNext); knn(node.getRight(), upper, best, _discNext);
} }
/** /**
@ -171,7 +176,7 @@ public class KDTree implements Serializable {
private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
int _disc) { int _disc) {
if (node == null || rect.minDistance(point) > dist) if (node == null || rect.minDistance(point, minDistance) > dist)
return Pair.of(Double.POSITIVE_INFINITY, null); return Pair.of(Double.POSITIVE_INFINITY, null);
int _discNext = (_disc + 1) % dims; int _discNext = (_disc + 1) % dims;

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.clustering.kdtree; package org.deeplearning4j.clustering.kdtree;
import org.joda.time.Instant;
import org.nd4j.shade.guava.base.Stopwatch;
import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val; import lombok.val;
import org.deeplearning4j.clustering.BaseDL4JTest; import org.deeplearning4j.clustering.BaseDL4JTest;
@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.guava.primitives.Floats;
import org.opencv.ml.KNearest; import org.opencv.ml.KNearest;
import java.util.ArrayList; import java.util.ArrayList;
@ -35,6 +38,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest {
@Before @Before
public void setUp() { public void setUp() {
kdTree = new KDTree(2); kdTree = new KDTree(2);
double[] data = new double[]{7,2}; float[] data = new float[]{7,2};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{5,4}; data = new float[]{5,4};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{2,3}; data = new float[]{2,3};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{4,7}; data = new float[]{4,7};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{9,6}; data = new float[]{9,6};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{8,1}; data = new float[]{8,1};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
} }
@ -168,26 +173,30 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void testKNN() { public void testKNN() {
int n = 10; int dimensions = 512;
// make a KD-tree of dimension {#n} int vectorsNo = 50000;
KDTree kdTree = new KDTree(n); // make a KD-tree of dimension {#dimensions}
for (int i = -1; i < n; i++) { Stopwatch stopwatch = Stopwatch.createStarted();
KDTree kdTree = new KDTree(dimensions);
for (int i = -1; i < vectorsNo; i++) {
// Insert a unit vector along each dimension // Insert a unit vector along each dimension
List<Double> vec = new ArrayList<>(n); INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions);
// i = -1 ensures the origin is in the Tree
for (int k = 0; k < n; k++) {
vec.add((k == i) ? 1.0 : 0.0);
}
INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec)));
kdTree.insert(indVec); kdTree.insert(indVec);
} }
stopwatch.stop();
System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS));
Random rand = new Random(); Random rand = new Random();
// random point in the Hypercube // random point in the Hypercube
List<Double> pt = new ArrayList(n); List<Double> pt = new ArrayList(dimensions);
for (int k = 0; k < n; k++) { for (int k = 0; k < dimensions; k++) {
pt.add(rand.nextDouble() * 10.0); pt.add(rand.nextFloat() * 10.0);
} }
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); stopwatch.reset();
stopwatch.start();
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
stopwatch.stop();
System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS));
} }
@Test @Test
@ -195,15 +204,15 @@ public class KDTreeTest extends BaseDL4JTest {
int n = 2; int n = 2;
KDTree kdTree = new KDTree(n); KDTree kdTree = new KDTree(n);
double[] data = new double[]{3,3}; float[] data = new float[]{3,3};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{1,1}; data = new float[]{1,1};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{2,2}; data = new float[]{2,2};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{0,0}; data = new float[]{0,0};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);
assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
@ -220,88 +229,88 @@ public class KDTreeTest extends BaseDL4JTest {
assertEquals(6, kdTree.size()); assertEquals(6, kdTree.size());
double[] data = new double[]{8,1}; float[] data = new float[]{8,1};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_2() { public void testKNN_2() {
double[] data = new double[]{8, 1}; float[] data = new float[]{8, 1};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_3() { public void testKNN_3() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_4() { public void testKNN_4() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_5() { public void testKNN_5() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void test_KNN_6() { public void test_KNN_6() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -318,8 +327,8 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void test_KNN_7() { public void test_KNN_7() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -334,8 +343,8 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void test_KNN_8() { public void test_KNN_8() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -392,12 +401,12 @@ public class KDTreeTest extends BaseDL4JTest {
Duration duration = new Duration(start, end); Duration duration = new Duration(start, end);
System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());
List<Double> pt = new ArrayList(num); List<Float> pt = new ArrayList(num);
for (int k = 0; k < n; k++) { for (int k = 0; k < n; k++) {
pt.add((double)(num / 2)); pt.add((float)(num / 2));
} }
start = System.currentTimeMillis(); start = System.currentTimeMillis();
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
end = System.currentTimeMillis(); end = System.currentTimeMillis();
duration = new Duration(start, end); duration = new Duration(start, end);
long elapsed = end - start; long elapsed = end - start;

View File

@ -39,6 +39,7 @@
#include <ops/declarable/headers/datatypes.h> #include <ops/declarable/headers/datatypes.h>
#include <ops/declarable/headers/third_party.h> #include <ops/declarable/headers/third_party.h>
#include <ops/declarable/headers/tests.h> #include <ops/declarable/headers/tests.h>
#include <ops/declarable/headers/kernels.h>
#include <ops/declarable/headers/BarnesHutTsne.h> #include <ops/declarable/headers/BarnesHutTsne.h>
#include <dll.h> #include <dll.h>
#include <helpers/shape.h> #include <helpers/shape.h>

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_knn_mindistance)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/knn.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto lowest = INPUT_VARIABLE(1);
auto highest = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length");
REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type");
helpers::knn_mindistance(*input, *lowest, *highest, *output);
return Status::OK();
}
DECLARE_SHAPE_FN(knn_mindistance) {
auto input = inputShape->at(0);
// always return scalar here
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input)));
}
DECLARE_TYPES(knn_mindistance) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_KERNELS_H
#define LIBND4J_KERNELS_H
#include <ops/declarable/headers/common.h>
namespace nd4j {
namespace ops {
#if NOT_EXCLUDED(OP_knn_mindistance)
DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0);
#endif
}
}
#endif //LIBND4J_KERNELS_H

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/knn.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) {
auto input = reinterpret_cast<const T*>(vinput);
auto low = reinterpret_cast<const T*>(vlow);
auto high = reinterpret_cast<const T*>(vhigh);
auto output = reinterpret_cast<T*>(vout);
T res = 0.0f;
T po = 2.f;
T o = 1.f;
#pragma omp simd reduction(sumT:res)
for (auto e = 0; e < length; e++) {
T p = input[e];
T l = low[e];
T h = high[e];
if (!(l <= p || h <= p)) {
if (p < l)
res += nd4j::math::nd4j_pow<T, T, T>((p - o), po);
else
res += nd4j::math::nd4j_pow<T, T, T>((p - h), po);
}
}
output[0] = nd4j::math::nd4j_pow<T, T, T>(res, (T) 0.5f);
}
void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) {
NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest});
BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.getBuffer(), lowest.getBuffer(), highest.getBuffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES);
NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest});
}
}
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// Declarations for the k-nearest-neighbours (KDTree) helper kernels.

#ifndef SAMEDIFF_KNN_H
#define SAMEDIFF_KNN_H

#include <ops/declarable/helpers/helpers.h>

namespace nd4j {
namespace ops {
namespace helpers {
    // Writes into `output` (expected to be a scalar NDArray) the minimum
    // Euclidean distance between the point `input` and the axis-aligned
    // hyper-rectangle whose corners are `lowest` and `highest`.
    void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output);
}
}
}

#endif //SAMEDIFF_KNN_H

View File

@ -217,7 +217,7 @@ namespace nd4j {
auto var = ctx.variable(pair); auto var = ctx.variable(pair);
auto shape = var->getNDArray()->shapeInfo(); auto shape = var->getNDArray()->shapeInfo();
if (!shape::equalsSoft(out, shape)) { if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
auto eShape = ShapeUtils::shapeAsString(out); auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(shape); auto aShape = ShapeUtils::shapeAsString(shape);
@ -237,7 +237,7 @@ namespace nd4j {
ctx.setOutputArray(idx, outArr, true); ctx.setOutputArray(idx, outArr, true);
} else { } else {
auto array = fout[idx]; auto array = fout[idx];
if (!shape::equalsSoft(out, array->shapeInfo())) { if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) {
auto eShape = ShapeUtils::shapeAsString(out); auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());

View File

@ -134,3 +134,20 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
delete result; delete result;
} }
// Smoke test for the knn_mindistance custom op: the point and both box
// corners are the identical linspace 1..512, so the point lies exactly on
// the hyper-rectangle and the distance should be 0. Only successful
// execution (Status::OK) is asserted; the numeric result is not checked.
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
    auto input = NDArrayFactory::create<float>('c', {512});
    auto low = NDArrayFactory::create<float>('c', {512});
    auto high = NDArrayFactory::create<float>('c', {512});

    // Pre-allocated scalar output the op writes the distance into.
    auto output = NDArrayFactory::create<float>(0.0f);

    input.linspace(1.0);
    low.linspace(1.0);
    high.linspace(1.0);

    nd4j::ops::knn_mindistance op;
    // In-place execution form: outputs are supplied by the caller.
    auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {});
    ASSERT_EQ(Status::OK(), result);
}

View File

@ -0,0 +1,23 @@
package org.nd4j.linalg.api.ops.custom;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
 * Custom op wrapping the native {@code knn_mindistance} kernel: computes the
 * minimum Euclidean distance between a point and an axis-aligned
 * hyper-rectangle (used by the KDTree nearest-neighbour search).
 */
public class KnnMinDistance extends DynamicCustomOp {

    /** No-arg constructor required for op instantiation/deserialization. */
    public KnnMinDistance() {
    }

    /**
     * @param point    coordinates of the query point
     * @param lowest   lower corner of the hyper-rectangle
     * @param highest  upper corner of the hyper-rectangle
     * @param distance pre-allocated scalar array that receives the distance
     */
    public KnnMinDistance(INDArray point, INDArray lowest, INDArray highest, INDArray distance) {
        // Use the public DynamicCustomOp API (null-checks and encapsulates the
        // argument lists) instead of mutating the protected fields directly.
        addInputArgument(point, lowest, highest);
        addOutputArgument(distance);
    }

    @Override
    public String opName() {
        return "knn_mindistance";
    }
}

View File

@ -52,20 +52,26 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
@Override @Override
public void setIArguments(long... arguments) { public void setIArguments(long... arguments) {
super.setIArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); super.setIArguments(arguments);
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
}
} }
@Override @Override
public void setBArguments(boolean... arguments) { public void setBArguments(boolean... arguments) {
super.setBArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); super.setBArguments(arguments);
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
}
} }
@Override @Override
public void setTArguments(double... arguments) { public void setTArguments(double... arguments) {
super.setTArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); super.setTArguments(arguments);
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
}
} }
@Override @Override

View File

@ -49,20 +49,26 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
@Override @Override
public void setIArguments(long... arguments) { public void setIArguments(long... arguments) {
super.setIArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); super.setIArguments(arguments);
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
}
} }
@Override @Override
public void setBArguments(boolean... arguments) { public void setBArguments(boolean... arguments) {
super.setBArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); super.setBArguments(arguments);
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
}
} }
@Override @Override
public void setTArguments(double... arguments) { public void setTArguments(double... arguments) {
super.setTArguments(arguments); if (arguments.length > 0) {
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); super.setTArguments(arguments);
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
};
} }
@Override @Override

View File

@ -1,4 +1,4 @@
// Targeted by JavaCPP version 1.5.2-SNAPSHOT: DO NOT EDIT THIS FILE // Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas; package org.nd4j.nativeblas;
@ -4703,6 +4703,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* k - depth * k - depth
* value - scalar value to assign * value - scalar value to assign
*/ */
public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value);
/** /**
* creates array which points on certain sub-range of this array, sub-range is defined by given indices * creates array which points on certain sub-range of this array, sub-range is defined by given indices
@ -4931,7 +4932,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const {
} else { } else {
Nd4jLong idx[MAX_RANK]; Nd4jLong idx[MAX_RANK];
shape::ind2subC(rankOf(), shapeOf(), i, idx); shape::ind2subC(rankOf(), shapeOf(), i, idx);
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), idx);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -4962,7 +4963,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) {
} else { } else {
Nd4jLong idx[MAX_RANK]; Nd4jLong idx[MAX_RANK];
shape::ind2subC(rankOf(), shapeOf(), i, idx); shape::ind2subC(rankOf(), shapeOf(), i, idx);
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), idx);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -4979,7 +4980,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const {
throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
Nd4jLong coords[2] = {i, j}; Nd4jLong coords[2] = {i, j};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
// TODO: do we really want a view here? // TODO: do we really want a view here?
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
@ -4995,7 +4996,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) {
throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !"); throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
Nd4jLong coords[2] = {i, j}; Nd4jLong coords[2] = {i, j};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -5014,7 +5015,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k
throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
Nd4jLong coords[3] = {i, j, k}; Nd4jLong coords[3] = {i, j, k};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -5031,7 +5032,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong
throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !"); throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
Nd4jLong coords[3] = {i, j, k}; Nd4jLong coords[3] = {i, j, k};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -5047,7 +5048,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v
throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
Nd4jLong coords[4] = {t, u, v, w}; Nd4jLong coords[4] = {t, u, v, w};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -5061,7 +5062,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong
throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !"); throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
Nd4jLong coords[4] = {t, u, v, w}; Nd4jLong coords[4] = {t, u, v, w};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), coords);
// FIXME // FIXME
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
@ -5077,7 +5078,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const {
if (idx[i] >= sizeAt(i)) if (idx[i] >= sizeAt(i))
throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), idx);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -5092,7 +5093,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
if (idx[i] >= sizeAt(i)) if (idx[i] >= sizeAt(i))
throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !"); throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf()); auto xOffset = shape::getOffset(getShapeInfo(), idx);
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT()); auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace())); NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@ -7958,9 +7959,7 @@ public static final int PREALLOC_SIZE = 33554432;
* @param indices the indices to iterate over * @param indices the indices to iterate over
* @return the double at the specified index * @return the double at the specified index
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@ -7981,34 +7980,26 @@ public static final int PREALLOC_SIZE = 33554432;
/** /**
* Convert a linear index to the corresponding coordinates * Convert a linear index to the corresponding coordinates
* for example if shape is {2, 4}, then index 5 corresponds to following coordinates * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
* -> [1, 1] in case of c order
* -> [1, 2] in case of f order
*/ */
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/);
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords);
/** /**
* Convert coordinates to the corresponding linear index (sequence number in other words) * Convert coordinates to the corresponding linear index (sequence number in other words)
* for example if shape is {2, 4}, then: * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
* in case of c order and coordinates [1, 1] index 5 is returned
* in case of f order and coordinates [1, 2] index 5 is returned
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords, byte order/*='c'*/); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords, byte order/*='c'*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords, byte order/*='c'*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
/** /**
@ -8020,36 +8011,16 @@ public static final int PREALLOC_SIZE = 33554432;
*/ */
/* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1}
* arrLen - array length
*/ */
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo, @Cast("uint") int arrLen); @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo);
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo, @Cast("uint") int arrLen); @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo);
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo, @Cast("uint") int arrLen); @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen); @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen); @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen); @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned);
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen, byte order); @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned);
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
/**
* Compute the real linear indices for the given shape and stride
*/
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride);
/**
* Compute the real linear indices for the
* given shape buffer. Shape,stride and rank are derived
* from the buffer
*/
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices( @Cast("Nd4jLong*") LongPointer shapeBuffer);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices( @Cast("Nd4jLong*") LongBuffer shapeBuffer);
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices( @Cast("Nd4jLong*") long[] shapeBuffer);
@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo);
@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo);
@ -8328,20 +8299,62 @@ public static final int PREALLOC_SIZE = 33554432;
* for the given rank and shape. * for the given rank and shape.
*/ */
/** //////////////////////////////////////////////////////////////////////
* Compute the real linear indices for the given shape and stride
*/
/**
* Compute the real linear indices for the given shape and stride
*/
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
// const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2];
// if(ews > 0 && order(shapeInfo) == 'c')
// if (ews == 1)
// return index;
// else
// return ews * index;
// Nd4jLong offset = 0;
// Nd4jLong rank = shapeInfo[0];
// for(int i = 1; i <= shapeInfo[0]; ++i) {
// arrLen /= shapeInfo[i];
// if(arrLen > 0 && shapeInfo[i] > 1) {
// offset += (index / arrLen) * shapeInfo[i + rank];
// index %= arrLen;
// }
// }
// return offset;
// }
// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) {
// const uint rank = shapeInfo[0];
// const uint ews = shapeInfo[rank + rank + 2];
// if(ews > 0 && shapeInfo[rank + rank + 3] == 99)
// if (ews == 1)
// return index;
// else
// return ews * index;
// uint offset = 0;
// for(uint i = 1; i <= rank; ++i) {
// arrLen /= shapeInfo[i];
// if(arrLen > 0 && shapeInfo[i] > 1) {
// offset += (index / arrLen) * shapeInfo[i + rank];
// index %= arrLen;
// }
// }
// return offset;
// }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/** /**
@ -8710,6 +8723,10 @@ public static final int PREALLOC_SIZE = 33554432;
* @return the double at the specified index * @return the double at the specified index
*/ */
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
@ -9038,6 +9055,8 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
@ -13850,6 +13869,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
* 2) if shape X is scalar and shape Y is array - result will have shape equal to Y.
* 3) if shape X is array and shape Y is scalar - result will have shape equal to X.
* 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result.
*
* This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0
*/
// #if NOT_EXCLUDED(OP_divide_no_nan)
// JavaCPP-generated JNI binding for the native broadcastable op nd4j::ops::divide_no_nan
// (per the header comment above: Z = Divide(X, Y), but 0 where Y = 0).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class divide_no_nan extends BroadcastableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public divide_no_nan(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public divide_no_nan position(long position) {
        return (divide_no_nan)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public divide_no_nan() { super((Pointer)null); allocate(); }
    private native void allocate();
}
// #endif
/** /**
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
* 1) if shapes are equal that's pairwise operation, result will have the same shape. * 1) if shapes are equal that's pairwise operation, result will have the same shape.
@ -14386,6 +14430,54 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* Broadcastable igamma implementation
*
* igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x)
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
* gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
* \tparam T
*/
// #if NOT_EXCLUDED(OP_igamma)
// JavaCPP-generated JNI binding for the native broadcastable op nd4j::ops::igamma
// (lower regularized incomplete gamma P(a, x); see the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class igamma extends BroadcastableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public igamma(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public igamma(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public igamma position(long position) {
        return (igamma)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public igamma() { super((Pointer)null); allocate(); }
    private native void allocate();
}
// #endif
/**
* Broadcastable igammac implementation
* igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x)
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
* Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
* \tparam T
*/
// #if NOT_EXCLUDED(OP_igammac)
// JavaCPP-generated JNI binding for the native broadcastable op nd4j::ops::igammac
// (upper regularized incomplete gamma Q(a, x); see the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class igammac extends BroadcastableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public igammac(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public igammac(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public igammac position(long position) {
        return (igammac)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public igammac() { super((Pointer)null); allocate(); }
    private native void allocate();
}
// #endif
// #endif // #endif
@ -15842,6 +15934,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
//////////////////////////////////////////////////////////////////////////
// #if NOT_EXCLUDED(OP_lstmLayer)
// JavaCPP-generated JNI binding for the native custom op nd4j::ops::lstmLayer.
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class lstmLayer extends DeclarableCustomOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public lstmLayer(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public lstmLayer(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public lstmLayer position(long position) {
        return (lstmLayer)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public lstmLayer() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
* Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
@ -17079,16 +17191,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
* *
* Input arrays: * Input arrays:
* input: input array, considered as batch of matrices * 0: input array, considered as batch of matrices
* diagonal: array containing elements to be inserted into input array, * 1: diagonal array containing elements to be inserted into input array,
* following rank condition should be satisfied: diagonal_rank = input_rank - 1, * following rank condition should be satisfied: diagonal_rank = input_rank - 1,
* the shapes of diagonal and input arrays must be equal except last dimension of input array, * the shapes of diagonal and input arrays must be equal except last dimension of input array,
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
* *
* Output array: * Output array:
* has the same shape as input, corresponding diagonal elements are substituted * 0: has the same shape as input, corresponding diagonal elements are substituted
*/ */
// #if NOT_EXCLUDED(OP_matrix_set_diag) // #if NOT_EXCLUDED(OP_matrix_set_diag)
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
@ -17109,8 +17221,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif // #endif
/** /**
* Returns a batched matrix tensor with diagonal values given (as TF.matrix_diag). * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array,
*/ * rest output elements are set to zeros
*
* Input array:
* diagonal: array containing elements to be inserted into output array,
 * following rank condition is present: diagonal_rank = output_rank - 1
*
* Output array:
* 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C]
*/
@Namespace("nd4j::ops") public static class matrix_diag extends DeclarableCustomOp { @Namespace("nd4j::ops") public static class matrix_diag extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -17130,13 +17250,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/** /**
* This op calculates regularized incomplete beta integral Ix(a, b). * This op calculates regularized incomplete beta integral Ix(a, b).
* Implementation is based on two algorithms depending on input values of a and b: * Implementation is based on two algorithms depending on input values of a and b:
* - when a and b are both > maxValue (3000.), then apply Gauss-Legendre quadrature method * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied
* - when a and b are both <= maxValue (3000.), then apply modified Lentzs algorithm for continued fractions * - when a and b are both <= maxValue (3000.), then modified Lentzs algorithm for continued fractions is applied
* *
* Input arrays: * Input arrays:
* a: define power t^{a-1}, must be > 0, type float. * a: defines power t^{a-1}, must be > 0, type float.
* b: define power (1-t)^{b-1}, must be > 0, type float. * b: defines power (1-t)^{b-1}, must be > 0, type float.
* x: define upper limit of integration, must be within (0 <= x <= 1) range, type float. * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float.
* *
* Output array: * Output array:
* 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float
@ -18250,6 +18370,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
* Input arrays:
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
*
* T arguments:
* 0 - contrast factor
*
*/
// #if NOT_EXCLUDED(OP_adjust_contrast)
// JavaCPP-generated JNI binding for the native op nd4j::ops::adjust_contrast
// (per the header comment above: z = (x - mean) * factor + mean).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class adjust_contrast extends DeclarableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public adjust_contrast(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public adjust_contrast position(long position) {
        return (adjust_contrast)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public adjust_contrast() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// JavaCPP-generated JNI binding for the native op nd4j::ops::adjust_contrast_v2
// (second variant of adjust_contrast; exact native semantics not visible here).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class adjust_contrast_v2 extends DeclarableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public adjust_contrast_v2(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public adjust_contrast_v2 position(long position) {
        return (adjust_contrast_v2)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public adjust_contrast_v2() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/** /**
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
@ -19634,6 +19798,37 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
 * draw_bounding_boxes op - modifies the input image by drawing the given boxes with the given colors.
*
* input params:
 * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channels is 1 (BW image),
* 3 (RGB) or 4 (RGBA)
* 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as
* (y_min, x_min, y_max, x_max), all values in between 0. and 1.
* 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette)
*
* output:
* 0 - 4D tensor with same shape as images (input 0)
*/
// #if NOT_EXCLUDED(OP_draw_bounding_boxes)
// JavaCPP-generated JNI binding for the native op nd4j::ops::draw_bounding_boxes
// (draws the given boxes onto the input image tensor; see the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class draw_bounding_boxes extends DeclarableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public draw_bounding_boxes(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public draw_bounding_boxes position(long position) {
        return (draw_bounding_boxes)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public draw_bounding_boxes() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/** /**
* roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html)
* *
@ -20623,6 +20818,67 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/**
* fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel
*
* input params:
* 0 - NDArray (input) - at least 2D.
* 1 - 1D Tensor - min values (min length equals to last dim of input)
* 2 - 1D Tensor - max value (length equals to min)
*
* int params (optional):
* 0 - num_bits (allowed interval [2, 16], default 8)
* 1 - narrow_range (default False)
*
* output:
* 0 - NDArray with the same shape as input
*/
// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)
// JavaCPP-generated JNI binding for the native op
// nd4j::ops::fake_quant_with_min_max_vars_per_channel
// (mirrors tf.quantization.fake_quant_with_min_max_vars_per_channel per the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public fake_quant_with_min_max_vars_per_channel position(long position) {
        return (fake_quant_with_min_max_vars_per_channel)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* compare_and_bitpack - compare with greater and pack result with uint8
*
* input params:
* 0 - NDArray (input)
* 1 - 0D Tensor - threshold
*
*
* output:
* 0 - NDArray with the same shape as input and type uint8
*/
// #if NOT_EXCLUDED(OP_compare_and_bitpack)
// JavaCPP-generated JNI binding for the native custom op nd4j::ops::compare_and_bitpack
// (compares input against a threshold and packs the boolean results into uint8;
// see the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class compare_and_bitpack extends DeclarableCustomOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public compare_and_bitpack(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public compare_and_bitpack position(long position) {
        return (compare_and_bitpack)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public compare_and_bitpack() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
@ -23102,6 +23358,28 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/**
 * This operation changes the type of the input and modifies the shape of the output to conform with the given data type
*
* all as above op
* */
// #if NOT_EXCLUDED(OP_bitcast)
// JavaCPP-generated JNI binding for the native custom op nd4j::ops::bitcast
// (reinterprets the input's raw bits as another data type, adjusting the output
// shape to match; see the header comment above).
// NOTE(review): auto-generated glue code — hand edits will be lost on regeneration.
@Namespace("nd4j::ops") public static class bitcast extends DeclarableCustomOp {
    // Ensure the native library is loaded before any native method is called.
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public bitcast(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public bitcast(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    // Covariant override so position() chaining preserves the concrete type.
    @Override public bitcast position(long position) {
        return (bitcast)super.position(position);
    }
    /** Default constructor: allocates a fresh native op instance. */
    public bitcast() { super((Pointer)null); allocate(); }
    private native void allocate();
    /** Delegates output-shape inference to the native op implementation. */
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

View File

@ -810,7 +810,7 @@ public class CustomOpsTests extends BaseNd4jTest {
@Test @Test
public void testAdjustContrast() { public void testAdjustContrast() {
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3); INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3);
INDArray out = Nd4j.zeros(4,4,3); INDArray out = Nd4j.zeros(DataType.DOUBLE,4, 4, 3);
INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
@ -920,4 +920,15 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, output); assertEquals(expected, output);
} }
@Test
public void testKnnMinDistance() {
    // Smoke test for the KnnMinDistance custom op (used by the KDTree optimization):
    // random 20-dimensional point and hyper-rectangle bounds.
    INDArray point = Nd4j.rand(DataType.FLOAT, 1, 20);
    INDArray lowest = Nd4j.rand(DataType.FLOAT, 1, 20);
    INDArray highest = Nd4j.rand(DataType.FLOAT, 1, 20);
    INDArray distance = Nd4j.scalar(0.f);

    Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));

    // Fix: the original test only printed the result and asserted nothing, so it
    // could never fail on a wrong output. A distance written by the op must at
    // least be a finite, non-negative scalar.
    float d = distance.getFloat(0);
    assertFalse("distance must not be NaN", Float.isNaN(d));
    assertFalse("distance must be finite", Float.isInfinite(d));
    assertTrue("distance must be non-negative, got " + d, d >= 0.0f);
}
} }