parent
35e6ffede4
commit
f31661e13b
|
@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree;
|
|||
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
|
@ -28,79 +31,103 @@ import java.util.List;
|
|||
*/
|
||||
public class HyperRect implements Serializable {
|
||||
|
||||
private List<Interval> points;
|
||||
//private List<Interval> points;
|
||||
private float[] lowerEnds;
|
||||
private float[] higherEnds;
|
||||
private INDArray lowerEndsIND;
|
||||
private INDArray higherEndsIND;
|
||||
|
||||
public HyperRect(List<Interval> points) {
|
||||
//this.points = points;
|
||||
this.points = new ArrayList<>(points.size());
|
||||
for (int i = 0; i < points.size(); ++i) {
|
||||
Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher);
|
||||
this.points.add(newInterval);
|
||||
}
|
||||
public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
|
||||
this.lowerEnds = new float[lowerEndsIn.length];
|
||||
this.higherEnds = new float[lowerEndsIn.length];
|
||||
System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
|
||||
System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
|
||||
lowerEndsIND = Nd4j.createFromArray(lowerEnds);
|
||||
higherEndsIND = Nd4j.createFromArray(higherEnds);
|
||||
}
|
||||
|
||||
public HyperRect(float[] point) {
|
||||
this(point, point);
|
||||
}
|
||||
|
||||
public HyperRect(Pair<float[], float[]> ends) {
|
||||
this(ends.getFirst(), ends.getSecond());
|
||||
}
|
||||
|
||||
|
||||
public void enlargeTo(INDArray point) {
|
||||
for (int i = 0; i < points.size(); i++)
|
||||
points.get(i).enlarge(point.getDouble(i));
|
||||
float[] pointAsArray = point.toFloatVector();
|
||||
for (int i = 0; i < lowerEnds.length; i++) {
|
||||
float p = pointAsArray[i];
|
||||
if (lowerEnds[i] > p)
|
||||
lowerEnds[i] = p;
|
||||
else if (higherEnds[i] < p)
|
||||
higherEnds[i] = p;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static List<Interval> point(INDArray vector) {
|
||||
List<Interval> ret = new ArrayList<>();
|
||||
public static Pair<float[],float[]> point(INDArray vector) {
|
||||
Pair<float[],float[]> ret = new Pair<>();
|
||||
float[] curr = new float[(int)vector.length()];
|
||||
for (int i = 0; i < vector.length(); i++) {
|
||||
double curr = vector.getDouble(i);
|
||||
ret.add(new Interval(curr, curr));
|
||||
curr[i] = vector.getFloat(i);
|
||||
}
|
||||
ret.setFirst(curr);
|
||||
ret.setSecond(curr);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
public List<Boolean> contains(INDArray hPoint) {
|
||||
/*public List<Boolean> contains(INDArray hPoint) {
|
||||
List<Boolean> ret = new ArrayList<>();
|
||||
for (int i = 0; i < hPoint.length(); i++)
|
||||
ret.add(points.get(i).contains(hPoint.getDouble(i)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
public double minDistance(INDArray hPoint) {
|
||||
double ret = 0.0;
|
||||
for (int i = 0; i < hPoint.length(); i++) {
|
||||
double p = hPoint.getDouble(i);
|
||||
Interval interval = points.get(i);
|
||||
if (!interval.contains(p)) {
|
||||
if (p < interval.lower)
|
||||
ret += Math.pow((p - interval.lower), 2);
|
||||
else
|
||||
ret += Math.pow((p - interval.higher), 2);
|
||||
}
|
||||
ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
|
||||
higherEnds[i] >= hPoint.getDouble(i));
|
||||
}
|
||||
|
||||
ret = Math.pow(ret, 0.5);
|
||||
return ret;
|
||||
}*/
|
||||
|
||||
public double minDistance(INDArray hPoint, INDArray output) {
|
||||
Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
|
||||
return output.getFloat(0);
|
||||
|
||||
/*double ret = 0.0;
|
||||
double[] pointAsArray = hPoint.toDoubleVector();
|
||||
for (int i = 0; i < pointAsArray.length; i++) {
|
||||
double p = pointAsArray[i];
|
||||
if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
|
||||
if (p < lowerEnds[i])
|
||||
ret += Math.pow((p - lowerEnds[i]), 2);
|
||||
else
|
||||
ret += Math.pow((p - higherEnds[i]), 2);
|
||||
}
|
||||
}
|
||||
ret = Math.pow(ret, 0.5);
|
||||
return ret;*/
|
||||
}
|
||||
|
||||
public HyperRect getUpper(INDArray hPoint, int desc) {
|
||||
Interval interval = points.get(desc);
|
||||
double d = hPoint.getDouble(desc);
|
||||
if (interval.higher < d)
|
||||
//Interval interval = points.get(desc);
|
||||
float higher = higherEnds[desc];
|
||||
float d = hPoint.getFloat(desc);
|
||||
if (higher < d)
|
||||
return null;
|
||||
HyperRect ret = new HyperRect(new ArrayList<>(points));
|
||||
Interval i2 = ret.points.get(desc);
|
||||
if (i2.lower < d)
|
||||
i2.lower = d;
|
||||
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
|
||||
if (ret.lowerEnds[desc] < d)
|
||||
ret.lowerEnds[desc] = d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
public HyperRect getLower(INDArray hPoint, int desc) {
|
||||
Interval interval = points.get(desc);
|
||||
double d = hPoint.getDouble(desc);
|
||||
if (interval.lower > d)
|
||||
//Interval interval = points.get(desc);
|
||||
float lower = lowerEnds[desc];
|
||||
float d = hPoint.getFloat(desc);
|
||||
if (lower > d)
|
||||
return null;
|
||||
HyperRect ret = new HyperRect(new ArrayList<>(points));
|
||||
Interval i2 = ret.points.get(desc);
|
||||
if (i2.higher > d)
|
||||
i2.higher = d;
|
||||
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
|
||||
//Interval i2 = ret.points.get(desc);
|
||||
if (ret.higherEnds[desc] > d)
|
||||
ret.higherEnds[desc] = d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -108,33 +135,10 @@ public class HyperRect implements Serializable {
|
|||
public String toString() {
|
||||
String retVal = "";
|
||||
retVal += "[";
|
||||
for (val point : points) {
|
||||
retVal += "(" + point.lower + " - " + point.higher + ") ";
|
||||
for (int i = 0; i < lowerEnds.length; ++i) {
|
||||
retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
|
||||
}
|
||||
retVal += "]";
|
||||
return retVal;
|
||||
}
|
||||
|
||||
public static class Interval {
|
||||
private double lower, higher;
|
||||
|
||||
public Interval(double lower, double higher) {
|
||||
this.lower = lower;
|
||||
this.higher = higher;
|
||||
}
|
||||
|
||||
public boolean contains(double point) {
|
||||
return lower <= point || point <= higher;
|
||||
|
||||
}
|
||||
|
||||
public void enlarge(double p) {
|
||||
if (lower > p)
|
||||
lower = p;
|
||||
else if (higher < p)
|
||||
higher = p;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ public class KDTree implements Serializable {
|
|||
|
||||
if (root == null) {
|
||||
root = new KDNode(point);
|
||||
rect = new HyperRect(HyperRect.point(point));
|
||||
rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
|
||||
} else {
|
||||
int disc = 0;
|
||||
KDNode node = root;
|
||||
|
@ -125,15 +125,21 @@ public class KDTree implements Serializable {
|
|||
return node.getPoint();
|
||||
}
|
||||
|
||||
// Share this data for recursive calls of "knn"
|
||||
private float currentDistance;
|
||||
private INDArray currentPoint;
|
||||
private INDArray minDistance = Nd4j.scalar(0.f);
|
||||
|
||||
|
||||
public List<Pair<Double, INDArray>> knn(INDArray point, double distance) {
|
||||
List<Pair<Double, INDArray>> best = new ArrayList<>();
|
||||
knn(root, point, rect, distance, best, 0);
|
||||
Collections.sort(best, new Comparator<Pair<Double, INDArray>>() {
|
||||
public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
|
||||
List<Pair<Float, INDArray>> best = new ArrayList<>();
|
||||
currentDistance = distance;
|
||||
currentPoint = point;
|
||||
knn(root, rect, best, 0);
|
||||
Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
|
||||
@Override
|
||||
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
|
||||
return Double.compare(o1.getKey(), o2.getKey());
|
||||
public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
|
||||
return Float.compare(o1.getKey(), o2.getKey());
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -141,22 +147,21 @@ public class KDTree implements Serializable {
|
|||
}
|
||||
|
||||
|
||||
private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best,
|
||||
int _disc) {
|
||||
if (node == null || rect == null || rect.minDistance(point) > dist)
|
||||
private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
|
||||
if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
|
||||
return;
|
||||
int _discNext = (_disc + 1) % dims;
|
||||
double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult()
|
||||
.doubleValue();
|
||||
float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
|
||||
.floatValue();
|
||||
|
||||
if (distance <= dist) {
|
||||
if (distance <= currentDistance) {
|
||||
best.add(Pair.of(distance, node.getPoint()));
|
||||
}
|
||||
|
||||
HyperRect lower = rect.getLower(node.point, _disc);
|
||||
HyperRect upper = rect.getUpper(node.point, _disc);
|
||||
knn(node.getLeft(), point, lower, dist, best, _discNext);
|
||||
knn(node.getRight(), point, upper, dist, best, _discNext);
|
||||
knn(node.getLeft(), lower, best, _discNext);
|
||||
knn(node.getRight(), upper, best, _discNext);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,7 +176,7 @@ public class KDTree implements Serializable {
|
|||
|
||||
private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
|
||||
int _disc) {
|
||||
if (node == null || rect.minDistance(point) > dist)
|
||||
if (node == null || rect.minDistance(point, minDistance) > dist)
|
||||
return Pair.of(Double.POSITIVE_INFINITY, null);
|
||||
|
||||
int _discNext = (_disc + 1) % dims;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.deeplearning4j.clustering.kdtree;
|
||||
|
||||
import org.joda.time.Instant;
|
||||
import org.nd4j.shade.guava.base.Stopwatch;
|
||||
import org.nd4j.shade.guava.primitives.Doubles;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.clustering.BaseDL4JTest;
|
||||
|
@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.shade.guava.primitives.Floats;
|
||||
import org.opencv.ml.KNearest;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -35,6 +38,8 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static java.util.concurrent.TimeUnit.MILLISECONDS;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
|
@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
@Before
|
||||
public void setUp() {
|
||||
kdTree = new KDTree(2);
|
||||
double[] data = new double[]{7,2};
|
||||
float[] data = new float[]{7,2};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{5,4};
|
||||
data = new float[]{5,4};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{2,3};
|
||||
data = new float[]{2,3};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{4,7};
|
||||
data = new float[]{4,7};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{9,6};
|
||||
data = new float[]{9,6};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{8,1};
|
||||
data = new float[]{8,1};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
}
|
||||
|
||||
|
@ -168,26 +173,30 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testKNN() {
|
||||
int n = 10;
|
||||
// make a KD-tree of dimension {#n}
|
||||
KDTree kdTree = new KDTree(n);
|
||||
for (int i = -1; i < n; i++) {
|
||||
int dimensions = 512;
|
||||
int vectorsNo = 50000;
|
||||
// make a KD-tree of dimension {#dimensions}
|
||||
Stopwatch stopwatch = Stopwatch.createStarted();
|
||||
KDTree kdTree = new KDTree(dimensions);
|
||||
for (int i = -1; i < vectorsNo; i++) {
|
||||
// Insert a unit vector along each dimension
|
||||
List<Double> vec = new ArrayList<>(n);
|
||||
// i = -1 ensures the origin is in the Tree
|
||||
for (int k = 0; k < n; k++) {
|
||||
vec.add((k == i) ? 1.0 : 0.0);
|
||||
}
|
||||
INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec)));
|
||||
INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions);
|
||||
kdTree.insert(indVec);
|
||||
}
|
||||
stopwatch.stop();
|
||||
System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS));
|
||||
|
||||
Random rand = new Random();
|
||||
// random point in the Hypercube
|
||||
List<Double> pt = new ArrayList(n);
|
||||
for (int k = 0; k < n; k++) {
|
||||
pt.add(rand.nextDouble() * 10.0);
|
||||
List<Double> pt = new ArrayList(dimensions);
|
||||
for (int k = 0; k < dimensions; k++) {
|
||||
pt.add(rand.nextFloat() * 10.0);
|
||||
}
|
||||
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0);
|
||||
stopwatch.reset();
|
||||
stopwatch.start();
|
||||
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
|
||||
stopwatch.stop();
|
||||
System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -195,15 +204,15 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
int n = 2;
|
||||
KDTree kdTree = new KDTree(n);
|
||||
|
||||
double[] data = new double[]{3,3};
|
||||
float[] data = new float[]{3,3};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{1,1};
|
||||
data = new float[]{1,1};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
data = new double[]{2,2};
|
||||
data = new float[]{2,2};
|
||||
kdTree.insert(Nd4j.createFromArray(data));
|
||||
|
||||
data = new double[]{0,0};
|
||||
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5);
|
||||
data = new float[]{0,0};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);
|
||||
|
||||
assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
|
@ -220,88 +229,88 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
|
||||
assertEquals(6, kdTree.size());
|
||||
|
||||
double[] data = new double[]{8,1};
|
||||
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
|
||||
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5);
|
||||
float[] data = new float[]{8,1};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
|
||||
assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKNN_2() {
|
||||
double[] data = new double[]{8, 1};
|
||||
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
|
||||
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5);
|
||||
float[] data = new float[]{8, 1};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
|
||||
assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKNN_3() {
|
||||
|
||||
double[] data = new double[]{2, 3};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
|
||||
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5);
|
||||
float[] data = new float[]{2, 3};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
|
||||
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testKNN_4() {
|
||||
double[] data = new double[]{2, 3};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
|
||||
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
|
||||
float[] data = new float[]{2, 3};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
|
||||
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKNN_5() {
|
||||
double[] data = new double[]{2, 3};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0);
|
||||
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5);
|
||||
float[] data = new float[]{2, 3};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
|
||||
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
|
||||
assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
|
||||
assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_KNN_6() {
|
||||
double[] data = new double[]{4, 6};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
|
||||
float[] data = new float[]{4, 6};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
|
||||
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
|
@ -318,8 +327,8 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void test_KNN_7() {
|
||||
double[] data = new double[]{4, 6};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
|
||||
float[] data = new float[]{4, 6};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
|
||||
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
|
@ -334,8 +343,8 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void test_KNN_8() {
|
||||
double[] data = new double[]{4, 6};
|
||||
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0);
|
||||
float[] data = new float[]{4, 6};
|
||||
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
|
||||
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
|
||||
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
|
||||
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
|
||||
|
@ -392,12 +401,12 @@ public class KDTreeTest extends BaseDL4JTest {
|
|||
Duration duration = new Duration(start, end);
|
||||
System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());
|
||||
|
||||
List<Double> pt = new ArrayList(num);
|
||||
List<Float> pt = new ArrayList(num);
|
||||
for (int k = 0; k < n; k++) {
|
||||
pt.add((double)(num / 2));
|
||||
pt.add((float)(num / 2));
|
||||
}
|
||||
start = System.currentTimeMillis();
|
||||
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0);
|
||||
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
|
||||
end = System.currentTimeMillis();
|
||||
duration = new Duration(start, end);
|
||||
long elapsed = end - start;
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include <ops/declarable/headers/datatypes.h>
|
||||
#include <ops/declarable/headers/third_party.h>
|
||||
#include <ops/declarable/headers/tests.h>
|
||||
#include <ops/declarable/headers/kernels.h>
|
||||
#include <ops/declarable/headers/BarnesHutTsne.h>
|
||||
#include <dll.h>
|
||||
#include <helpers/shape.h>
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_knn_mindistance)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/knn.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto lowest = INPUT_VARIABLE(1);
|
||||
auto highest = INPUT_VARIABLE(2);
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length");
|
||||
REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type");
|
||||
|
||||
helpers::knn_mindistance(*input, *lowest, *highest, *output);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(knn_mindistance) {
|
||||
auto input = inputShape->at(0);
|
||||
|
||||
// always return scalar here
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input)));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(knn_mindistance) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_KERNELS_H
|
||||
#define LIBND4J_KERNELS_H
|
||||
|
||||
#include <ops/declarable/headers/common.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
#if NOT_EXCLUDED(OP_knn_mindistance)
|
||||
DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#endif //LIBND4J_KERNELS_H
|
|
@ -0,0 +1,62 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/knn.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) {
|
||||
auto input = reinterpret_cast<const T*>(vinput);
|
||||
auto low = reinterpret_cast<const T*>(vlow);
|
||||
auto high = reinterpret_cast<const T*>(vhigh);
|
||||
auto output = reinterpret_cast<T*>(vout);
|
||||
|
||||
T res = 0.0f;
|
||||
T po = 2.f;
|
||||
T o = 1.f;
|
||||
|
||||
#pragma omp simd reduction(sumT:res)
|
||||
for (auto e = 0; e < length; e++) {
|
||||
T p = input[e];
|
||||
T l = low[e];
|
||||
T h = high[e];
|
||||
if (!(l <= p || h <= p)) {
|
||||
if (p < l)
|
||||
res += nd4j::math::nd4j_pow<T, T, T>((p - o), po);
|
||||
else
|
||||
res += nd4j::math::nd4j_pow<T, T, T>((p - h), po);
|
||||
}
|
||||
}
|
||||
|
||||
output[0] = nd4j::math::nd4j_pow<T, T, T>(res, (T) 0.5f);
|
||||
}
|
||||
|
||||
void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) {
|
||||
NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest});
|
||||
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.getBuffer(), lowest.getBuffer(), highest.getBuffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES);
|
||||
|
||||
NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SAMEDIFF_KNN_H
|
||||
#define SAMEDIFF_KNN_H
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //SAMEDIFF_KNN_H
|
|
@ -217,7 +217,7 @@ namespace nd4j {
|
|||
auto var = ctx.variable(pair);
|
||||
auto shape = var->getNDArray()->shapeInfo();
|
||||
|
||||
if (!shape::equalsSoft(out, shape)) {
|
||||
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
|
||||
auto eShape = ShapeUtils::shapeAsString(out);
|
||||
auto aShape = ShapeUtils::shapeAsString(shape);
|
||||
|
||||
|
@ -237,7 +237,7 @@ namespace nd4j {
|
|||
ctx.setOutputArray(idx, outArr, true);
|
||||
} else {
|
||||
auto array = fout[idx];
|
||||
if (!shape::equalsSoft(out, array->shapeInfo())) {
|
||||
if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) {
|
||||
auto eShape = ShapeUtils::shapeAsString(out);
|
||||
auto aShape = ShapeUtils::shapeAsString(array->shapeInfo());
|
||||
|
||||
|
|
|
@ -134,3 +134,20 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
|
||||
auto input = NDArrayFactory::create<float>('c', {512});
|
||||
auto low = NDArrayFactory::create<float>('c', {512});
|
||||
auto high = NDArrayFactory::create<float>('c', {512});
|
||||
|
||||
auto output = NDArrayFactory::create<float>(0.0f);
|
||||
|
||||
input.linspace(1.0);
|
||||
low.linspace(1.0);
|
||||
high.linspace(1.0);
|
||||
|
||||
nd4j::ops::knn_mindistance op;
|
||||
auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
public class KnnMinDistance extends DynamicCustomOp {
|
||||
|
||||
public KnnMinDistance() {
|
||||
}
|
||||
|
||||
public KnnMinDistance(INDArray point, INDArray lowest, INDArray highest, INDArray distance) {
|
||||
inputArguments.add(point);
|
||||
inputArguments.add(lowest);
|
||||
inputArguments.add(highest);
|
||||
|
||||
outputArguments.add(distance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "knn_mindistance";
|
||||
}
|
||||
}
|
|
@ -52,20 +52,26 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
|
||||
@Override
|
||||
public void setIArguments(long... arguments) {
|
||||
super.setIArguments(arguments);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setIArguments(arguments);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBArguments(boolean... arguments) {
|
||||
super.setBArguments(arguments);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setBArguments(arguments);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTArguments(double... arguments) {
|
||||
super.setTArguments(arguments);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setTArguments(arguments);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -49,20 +49,26 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
|||
|
||||
@Override
|
||||
public void setIArguments(long... arguments) {
|
||||
super.setIArguments(arguments);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setIArguments(arguments);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBArguments(boolean... arguments) {
|
||||
super.setBArguments(arguments);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setBArguments(arguments);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTArguments(double... arguments) {
|
||||
super.setTArguments(arguments);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
if (arguments.length > 0) {
|
||||
super.setTArguments(arguments);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Targeted by JavaCPP version 1.5.2-SNAPSHOT: DO NOT EDIT THIS FILE
|
||||
// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
|
@ -4703,6 +4703,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
|||
* k - depth
|
||||
* value - scalar value to assign
|
||||
*/
|
||||
public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value);
|
||||
|
||||
/**
|
||||
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
||||
|
@ -4931,7 +4932,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const {
|
|||
} else {
|
||||
Nd4jLong idx[MAX_RANK];
|
||||
shape::ind2subC(rankOf(), shapeOf(), i, idx);
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), idx);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -4962,7 +4963,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) {
|
|||
} else {
|
||||
Nd4jLong idx[MAX_RANK];
|
||||
shape::ind2subC(rankOf(), shapeOf(), i, idx);
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), idx);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -4979,7 +4980,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const {
|
|||
throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
|
||||
|
||||
Nd4jLong coords[2] = {i, j};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
// TODO: do we really want a view here?
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
|
@ -4995,7 +4996,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) {
|
|||
throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
|
||||
|
||||
Nd4jLong coords[2] = {i, j};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -5014,7 +5015,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k
|
|||
throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
|
||||
|
||||
Nd4jLong coords[3] = {i, j, k};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -5031,7 +5032,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong
|
|||
throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
|
||||
|
||||
Nd4jLong coords[3] = {i, j, k};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -5047,7 +5048,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v
|
|||
throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
|
||||
|
||||
Nd4jLong coords[4] = {t, u, v, w};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -5061,7 +5062,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong
|
|||
throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
|
||||
|
||||
Nd4jLong coords[4] = {t, u, v, w};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), coords);
|
||||
|
||||
// FIXME
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
|
@ -5077,7 +5078,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const {
|
|||
if (idx[i] >= sizeAt(i))
|
||||
throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
|
||||
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), idx);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -5092,7 +5093,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
|||
if (idx[i] >= sizeAt(i))
|
||||
throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
|
||||
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
|
||||
auto xOffset = shape::getOffset(getShapeInfo(), idx);
|
||||
|
||||
auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
|
||||
NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
|
||||
|
@ -7958,9 +7959,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* @param indices the indices to iterate over
|
||||
* @return the double at the specified index
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
|
||||
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||
|
@ -7981,34 +7980,26 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
/**
|
||||
* Convert a linear index to the corresponding coordinates
|
||||
* for example if shape is {2, 4}, then index 5 corresponds to following coordinates
|
||||
* -> [1, 1] in case of c order
|
||||
* -> [1, 2] in case of f order
|
||||
* for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
|
||||
*/
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong") long arrLen, @Cast("Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native void index2coords(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong") long index, @Cast("Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||
* for example if shape is {2, 4}, then:
|
||||
* in case of c order and coordinates [1, 1] index 5 is returned
|
||||
* in case of f order and coordinates [1, 2] index 5 is returned
|
||||
* for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords, byte order/*='c'*/);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
||||
|
||||
/**
|
||||
|
@ -8020,36 +8011,16 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
*/
|
||||
|
||||
/* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1}
|
||||
* arrLen - array length
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo, @Cast("uint") int arrLen);
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo, @Cast("uint") int arrLen);
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo, @Cast("uint") int arrLen);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong") long arrLen, byte order);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOrderOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong") long arrLen, byte order);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("Nd4jLong") long arrLen, @Cast("const bool") boolean useUnsigned);
|
||||
|
||||
/**
|
||||
* Compute the real linear indices for the given shape and stride
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride);
|
||||
|
||||
/**
|
||||
* Compute the real linear indices for the
|
||||
* given shape buffer. Shape,stride and rank are derived
|
||||
* from the buffer
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeIndices( @Cast("Nd4jLong*") LongPointer shapeBuffer);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeIndices( @Cast("Nd4jLong*") LongBuffer shapeBuffer);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeIndices( @Cast("Nd4jLong*") long[] shapeBuffer);
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned);
|
||||
|
||||
@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
|
@ -8328,20 +8299,62 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* for the given rank and shape.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Compute the real linear indices for the given shape and stride
|
||||
*/
|
||||
|
||||
/**
|
||||
* Compute the real linear indices for the given shape and stride
|
||||
*/
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
// //////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
|
||||
|
||||
// const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2];
|
||||
|
||||
// if(ews > 0 && order(shapeInfo) == 'c')
|
||||
// if (ews == 1)
|
||||
// return index;
|
||||
// else
|
||||
// return ews * index;
|
||||
|
||||
// Nd4jLong offset = 0;
|
||||
// Nd4jLong rank = shapeInfo[0];
|
||||
// for(int i = 1; i <= shapeInfo[0]; ++i) {
|
||||
// arrLen /= shapeInfo[i];
|
||||
// if(arrLen > 0 && shapeInfo[i] > 1) {
|
||||
// offset += (index / arrLen) * shapeInfo[i + rank];
|
||||
// index %= arrLen;
|
||||
// }
|
||||
// }
|
||||
// return offset;
|
||||
// }
|
||||
|
||||
// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) {
|
||||
|
||||
// const uint rank = shapeInfo[0];
|
||||
// const uint ews = shapeInfo[rank + rank + 2];
|
||||
|
||||
// if(ews > 0 && shapeInfo[rank + rank + 3] == 99)
|
||||
// if (ews == 1)
|
||||
// return index;
|
||||
// else
|
||||
// return ews * index;
|
||||
|
||||
// uint offset = 0;
|
||||
|
||||
// for(uint i = 1; i <= rank; ++i) {
|
||||
// arrLen /= shapeInfo[i];
|
||||
// if(arrLen > 0 && shapeInfo[i] > 1) {
|
||||
// offset += (index / arrLen) * shapeInfo[i + rank];
|
||||
// index %= arrLen;
|
||||
// }
|
||||
// }
|
||||
// return offset;
|
||||
// }
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
|
@ -8710,6 +8723,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* @return the double at the specified index
|
||||
*/
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -9038,6 +9055,8 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -13850,6 +13869,31 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
|
||||
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
|
||||
* 2) if shape X is scalar and shape Y is array - result will have shape equal to Y.
|
||||
* 3) if shape X is array and shape Y is scalar - result will have shape equal to X.
|
||||
* 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result.
|
||||
*
|
||||
* This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_divide_no_nan)
|
||||
@Namespace("nd4j::ops") public static class divide_no_nan extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public divide_no_nan(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public divide_no_nan position(long position) {
|
||||
return (divide_no_nan)super.position(position);
|
||||
}
|
||||
|
||||
public divide_no_nan() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
/**
|
||||
* This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
|
||||
* 1) if shapes are equal that's pairwise operation, result will have the same shape.
|
||||
|
@ -14386,6 +14430,54 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* Broadcastable igamma implementation
|
||||
*
|
||||
* igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x)
|
||||
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt }
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_igamma)
|
||||
@Namespace("nd4j::ops") public static class igamma extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public igamma(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public igamma(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public igamma position(long position) {
|
||||
return (igamma)super.position(position);
|
||||
}
|
||||
|
||||
public igamma() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
/**
|
||||
* Broadcastable igammac implementation
|
||||
* igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x)
|
||||
* Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt }
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_igammac)
|
||||
@Namespace("nd4j::ops") public static class igammac extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public igammac(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public igammac(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public igammac position(long position) {
|
||||
return (igammac)super.position(position);
|
||||
}
|
||||
|
||||
public igammac() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
|
||||
// #endif
|
||||
|
@ -15842,6 +15934,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// #if NOT_EXCLUDED(OP_lstmLayer)
|
||||
@Namespace("nd4j::ops") public static class lstmLayer extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public lstmLayer(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public lstmLayer(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public lstmLayer position(long position) {
|
||||
return (lstmLayer)super.position(position);
|
||||
}
|
||||
|
||||
public lstmLayer() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
/**
|
||||
* Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
||||
|
@ -17079,16 +17191,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
|
||||
*
|
||||
* Input arrays:
|
||||
* input: input array, considered as batch of matrices
|
||||
* diagonal: array containing elements to be inserted into input array,
|
||||
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
|
||||
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
|
||||
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
|
||||
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
|
||||
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
|
||||
* 0: input array, considered as batch of matrices
|
||||
* 1: diagonal array containing elements to be inserted into input array,
|
||||
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
|
||||
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
|
||||
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
|
||||
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
|
||||
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
|
||||
*
|
||||
* Output array:
|
||||
* has the same shape as input, corresponding diagonal elements are substituted
|
||||
* 0: has the same shape as input, corresponding diagonal elements are substituted
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_matrix_set_diag)
|
||||
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
|
||||
|
@ -17109,8 +17221,16 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// #endif
|
||||
|
||||
/**
|
||||
* Returns a batched matrix tensor with diagonal values given (as TF.matrix_diag).
|
||||
*/
|
||||
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array,
|
||||
* rest output elements are set to zeros
|
||||
*
|
||||
* Input array:
|
||||
* diagonal: array containing elements to be inserted into output array,
|
||||
* following rank condition is present: diagonal_rank = ouput_rank - 1
|
||||
*
|
||||
* Output array:
|
||||
* 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C]
|
||||
*/
|
||||
@Namespace("nd4j::ops") public static class matrix_diag extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
|
@ -17130,13 +17250,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
/**
|
||||
* This op calculates regularized incomplete beta integral Ix(a, b).
|
||||
* Implementation is based on two algorithms depending on input values of a and b:
|
||||
* - when a and b are both > maxValue (3000.), then apply Gauss-Legendre quadrature method
|
||||
* - when a and b are both <= maxValue (3000.), then apply modified Lentz’s algorithm for continued fractions
|
||||
* - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied
|
||||
* - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied
|
||||
*
|
||||
* Input arrays:
|
||||
* a: define power t^{a-1}, must be > 0, type float.
|
||||
* b: define power (1-t)^{b-1}, must be > 0, type float.
|
||||
* x: define upper limit of integration, must be within (0 <= x <= 1) range, type float.
|
||||
* a: defines power t^{a-1}, must be > 0, type float.
|
||||
* b: defines power (1-t)^{b-1}, must be > 0, type float.
|
||||
* x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float.
|
||||
*
|
||||
* Output array:
|
||||
* 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float
|
||||
|
@ -18250,6 +18370,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
|
||||
* Input arrays:
|
||||
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
|
||||
*
|
||||
* T arguments:
|
||||
* 0 - contrast factor
|
||||
*
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_adjust_contrast)
|
||||
@Namespace("nd4j::ops") public static class adjust_contrast extends DeclarableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public adjust_contrast(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public adjust_contrast position(long position) {
|
||||
return (adjust_contrast)super.position(position);
|
||||
}
|
||||
|
||||
public adjust_contrast() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
@Namespace("nd4j::ops") public static class adjust_contrast_v2 extends DeclarableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public adjust_contrast_v2(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public adjust_contrast_v2 position(long position) {
|
||||
return (adjust_contrast_v2)super.position(position);
|
||||
}
|
||||
|
||||
public adjust_contrast_v2() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
|
||||
|
@ -19634,6 +19798,37 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* draw_bounding_boxes op - modified input image with given colors exept given boxes.
|
||||
*
|
||||
* input params:
|
||||
* 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image),
|
||||
* 3 (RGB) or 4 (RGBA)
|
||||
* 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as
|
||||
* (y_min, x_min, y_max, x_max), all values in between 0. and 1.
|
||||
* 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette)
|
||||
*
|
||||
* output:
|
||||
* 0 - 4D tensor with same shape as images (input 0)
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_draw_bounding_boxes)
|
||||
@Namespace("nd4j::ops") public static class draw_bounding_boxes extends DeclarableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public draw_bounding_boxes(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public draw_bounding_boxes position(long position) {
|
||||
return (draw_bounding_boxes)super.position(position);
|
||||
}
|
||||
|
||||
public draw_bounding_boxes() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html)
|
||||
*
|
||||
|
@ -20623,6 +20818,67 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel
|
||||
*
|
||||
* input params:
|
||||
* 0 - NDArray (input) - at least 2D.
|
||||
* 1 - 1D Tensor - min values (min length equals to last dim of input)
|
||||
* 2 - 1D Tensor - max value (length equals to min)
|
||||
*
|
||||
* int params (optional):
|
||||
* 0 - num_bits (allowed interval [2, 16], default 8)
|
||||
* 1 - narrow_range (default False)
|
||||
*
|
||||
* output:
|
||||
* 0 - NDArray with the same shape as input
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)
|
||||
@Namespace("nd4j::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public fake_quant_with_min_max_vars_per_channel position(long position) {
|
||||
return (fake_quant_with_min_max_vars_per_channel)super.position(position);
|
||||
}
|
||||
|
||||
public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* compare_and_bitpack - compare with greater and pack result with uint8
|
||||
*
|
||||
* input params:
|
||||
* 0 - NDArray (input)
|
||||
* 1 - 0D Tensor - threshold
|
||||
*
|
||||
*
|
||||
* output:
|
||||
* 0 - NDArray with the same shape as input and type uint8
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_compare_and_bitpack)
|
||||
@Namespace("nd4j::ops") public static class compare_and_bitpack extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public compare_and_bitpack(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public compare_and_bitpack position(long position) {
|
||||
return (compare_and_bitpack)super.position(position);
|
||||
}
|
||||
|
||||
public compare_and_bitpack() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
|
||||
|
@ -23102,6 +23358,28 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
/**
|
||||
* This operation change type of input and modified shape of output to conform with given data type
|
||||
*
|
||||
* all as above op
|
||||
* */
|
||||
// #if NOT_EXCLUDED(OP_bitcast)
|
||||
@Namespace("nd4j::ops") public static class bitcast extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitcast(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitcast(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitcast position(long position) {
|
||||
return (bitcast)super.position(position);
|
||||
}
|
||||
|
||||
public bitcast() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -810,7 +810,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
@Test
|
||||
public void testAdjustContrast() {
|
||||
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3);
|
||||
INDArray out = Nd4j.zeros(4,4,3);
|
||||
INDArray out = Nd4j.zeros(DataType.DOUBLE,4, 4, 3);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
|
||||
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
|
||||
|
@ -920,4 +920,15 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(expected, output);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testKnnMinDistance() {
|
||||
INDArray point = Nd4j.rand(DataType.FLOAT, 1, 20);
|
||||
INDArray lowest = Nd4j.rand(DataType.FLOAT, 1, 20);
|
||||
INDArray highest = Nd4j.rand(DataType.FLOAT, 1, 20);
|
||||
INDArray distance = Nd4j.scalar(0.f);
|
||||
|
||||
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
|
||||
System.out.println(distance);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue