From 5a4d2e8b317bf5866da85c0fd2e49248bbf48620 Mon Sep 17 00:00:00 2001
From: raver119
Date: Mon, 28 Oct 2019 12:31:01 +0300
Subject: [PATCH] [WIP] SVD (#16)

* - new SVD constructor
  - OrthogonalDistribution now uses SVD custom op

Signed-off-by: raver119

* shapes fixed

Signed-off-by: raver119
---
 .../api/ops/impl/transforms/custom/Svd.java   | 14 ++++
 .../impl/OrthogonalDistribution.java          | 16 ++---
 .../java/org/nd4j/nativeblas/Nd4jCpu.java     | 70 ++++++++-----------
 .../java/org/nd4j/linalg/rng/RandomTests.java | 18 +++++
 4 files changed, 69 insertions(+), 49 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java
index 186205df7..21ceaed83 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java
@@ -47,6 +47,20 @@ public class Svd extends DynamicCustomOp {
 
     public Svd(){ }
 
+    public Svd(INDArray input, boolean full_matrices, INDArray s, INDArray u, INDArray v) {
+        inputArguments.add(input);
+        fullUV = full_matrices;
+        computeUv = true;
+        switchNum = DEFAULT_SWITCHNUM;
+
+
+        outputArguments.add(s);
+        outputArguments.add(u);
+        outputArguments.add(v);
+
+        addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum);
+    }
+
     public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){
         this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM);
    }
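The new constructor above is the eager (non-SameDiff) entry point: the caller pre-allocates all three outputs. For an [m, n] input with full_matrices = true, the shapes that match the op are s = [min(m, n)], u = [m, m] and v = [n, n]; the OrthogonalDistribution change below relies on exactly this convention. A minimal usage sketch (not part of the patch; the class and variable names are illustrative only):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
    import org.nd4j.linalg.factory.Nd4j;

    public class SvdConstructorSketch {
        public static void main(String[] args) {
            // [4, 3] input: m = 4, n = 3
            INDArray a = Nd4j.rand(4, 3);

            // Pre-allocated outputs, matching the full_matrices = true convention
            INDArray s = Nd4j.create(a.dataType(), 3);                      // min(m, n) singular values
            INDArray u = Nd4j.create(a.dataType(), 4, 4);                   // left singular vectors, [m, m]
            INDArray v = Nd4j.create(a.dataType(), new long[] {3, 3}, 'f'); // right singular vectors, [n, n]

            Nd4j.exec(new Svd(a, true, s, u, v));

            System.out.println("singular values: " + s);
        }
    }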
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java
index 4452bc7e4..31d24f11e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java
@@ -21,6 +21,7 @@ import lombok.val;
 import org.apache.commons.math3.exception.NumberIsTooLargeException;
 import org.apache.commons.math3.exception.OutOfRangeException;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
 import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
 import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
 import org.nd4j.linalg.factory.Nd4j;
@@ -231,21 +232,20 @@ public class OrthogonalDistribution extends BaseDistribution {
         val flatShape = new long[]{numRows, numCols};
         val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
 
-        long m = flatRng.rows();
-        long n = flatRng.columns();
+        val m = flatRng.rows();
+        val n = flatRng.columns();
 
         val s = Nd4j.create(dtype, m < n ? m : n);
-        val u = m < n ? Nd4j.create(dtype, m, n) : Nd4j.create(dtype, m, m);
+        val u = Nd4j.create(dtype, m, m);
         val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
 
-        Nd4j.getBlasWrapper().lapack().gesvd(flatRng, s, u, v);
+        Nd4j.exec(new Svd(flatRng, true, s, u, v));
 
-        // FIXME: int cast
         if (gains == null) {
-            if (u.rows() == numRows && u.columns() == numCols) {
-                return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
-            } else {
+            if (u.rows() >= numRows && u.columns() >= numCols) {
                 return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
+            } else {
+                return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
             }
         } else {
             throw new UnsupportedOperationException();
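The rewritten sample path is the standard orthogonal-init recipe: draw a [numRows, numCols] Gaussian, take a full SVD, then slice whichever factor is large enough to cover the requested shape (u is [m, m], v is [n, n], so a tall request slices u and a wide one slices v). A self-contained sketch of the same recipe using the new constructor; the helper class is hypothetical and not part of the patch:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class OrthogonalInitSketch {
        public static void main(String[] args) {
            int numRows = 9, numCols = 6;

            // 1) Gaussian noise of the flattened target shape
            INDArray g = Nd4j.randn(numRows, numCols);

            // 2) Full SVD: u is [m, m], v is [n, n]
            INDArray s = Nd4j.create(g.dataType(), Math.min(numRows, numCols));
            INDArray u = Nd4j.create(g.dataType(), numRows, numRows);
            INDArray v = Nd4j.create(g.dataType(), new long[] {numCols, numCols}, 'f');
            Nd4j.exec(new Svd(g, true, s, u, v));

            // 3) Tall request (u is 9x9, covers 9x6), so slice the top-left [9, 6] block of u
            INDArray w = u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols));

            // Columns sliced from an orthogonal matrix are orthonormal: w^T w ~ I(6)
            System.out.println(w.transpose().mmul(w));
        }
    }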
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index 047cb4021..c84846a5f 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -598,6 +598,10 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
 
         public native @Cast("bool") boolean isCPU();
 
+        public native int blasMajorVersion();
+        public native int blasMinorVersion();
+        public native int blasPatchVersion();
+
         public native @StdVector Pair capabilities();
     }
 
@@ -12281,6 +12285,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #include
 // #include
 // #include
+// #include
 // #include
 // #include
 // #include
@@ -21398,12 +21403,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
         private native void allocate();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
     }
-// #endif 
+// #endif
 
 /**
  * Local response normalization implementation as TF.
  * input: 4D array
- * 
+ *
  * T args:
  *
  * 0: bias
@@ -21411,8 +21416,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
  * 2: beta
  *
  * Int arg: depth - optional local radius
- * 
- * output - 4D array 
+ *
+ * output - 4D array
  */
 // #if NOT_EXCLUDED(OP_lrn)
         @Namespace("nd4j::ops") public static class lrn extends DeclarableOp {
@@ -21434,10 +21439,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 
 /**
  * Local response normalization - backprop variant.
- * input: 
+ * input:
  * 0 - 4D array of data
 * 1 - epsilon - 4D array of approximation
- * 
+ *
 * T args:
 *
 * 0: bias
@@ -21467,21 +21472,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #endif
 
 /**
- * Batch normalization implementation. 
+ * Batch normalization implementation.
 * Reference: https://arxiv.org/abs/1502.03167v3
- * 
+ *
 * Expected arguments:
 * input: input array (any number of dimensions)
 * mean:
 * variance:
 * gamma:
 * beta:
- * 
+ *
 * Int args:
 * 0: apply scale
 * 1: apply offset
- * 
- * 
+ *
+ *
 * T args:
 * 0: epsilon
 */
@@ -21502,27 +21507,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
     }
 // #endif
-// #if NOT_EXCLUDED(OP_batchnorm_new)
-        @Namespace("nd4j::ops") public static class batchnorm_new extends DeclarableCustomOp {
-            static { Loader.load(); }
-            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
-            public batchnorm_new(Pointer p) { super(p); }
-            /** Native array allocator. Access with {@link Pointer#position(long)}. */
-            public batchnorm_new(long size) { super((Pointer)null); allocateArray(size); }
-            private native void allocateArray(long size);
-            @Override public batchnorm_new position(long position) {
-                return (batchnorm_new)super.position(position);
-            }
-
-            public batchnorm_new() { super((Pointer)null); allocate(); }
-            private native void allocate();
-            public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-        }
-// #endif
 
 /**
 * back prop in batch normalization
- * 
+ *
 * Expected arguments:
 * input: input array (any number of dimensions)
 * mean:
 * variance:
 * gamma: optional
 * beta: optional
 * dLdOut: next epsilon
- * 
+ *
 * Int args:
 * 0: apply scale
- * 1: apply offset 
+ * 1: apply offset
- * 
+ *
 * T args:
 * 0: epsilon
 *
 * output arrays:
 * dL/dInput
 * dL/dMean
 * dL/dVariance
- * dL/dGamma
- * dL/dBeta
+ * dL/dGamma, optional
+ * dL/dBeta, optional
 */
 // #if NOT_EXCLUDED(OP_batchnorm)
         @Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp {
@@ -21570,7 +21558,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 * x: parameters, any shape
 * y: gradients. same shape as x
 * lr: optional, learning rate
- * 
+ *
 * T args:
 * 0: optional, learning rate
 */
         public apply_sgd() { super((Pointer)null); allocate(); }
         private native void allocate();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
-    } 
+    }
 // #endif
 
 /**
 * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
 * Expected arguments:
 * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
 *      bS - batch size
 *      iH - input height
 *      iW - input width
 *      iD - input depth (or number of channels)
 * scale: 1D input array of scale factors, shape [iD]
 * offset: 1D input array of offsets (shifts), shape [iD]
 * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
 * variance: 1D input array of population variance used for inference, shape [iD], this array is required only if isTraining = false
 *
 * T input arguments:
 * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
 *
 * integer input arguments:
 * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
 * 1: isTraining, may have two values: zero -> inference, unity -> training
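The Nd4jCpu.java changes are regenerated JavaCPP bindings: three BLAS version getters are added to the environment class, the batchnorm_new binding is removed (batchnorm remains the single op), and the rest is trailing-whitespace cleanup in the generated javadoc. To make the documented argument layout concrete, here is a hedged invocation sketch via the generic DynamicCustomOp builder. It follows only the javadoc above (inputs input/mean/variance plus optional gamma and beta, integer args "apply scale" and "apply offset", T arg epsilon); the actual libnd4j op may additionally expect an axis integer argument, so treat this as illustrative rather than canonical:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class BatchnormArgLayoutSketch {
        public static void main(String[] args) {
            // NHWC input [bS, iH, iW, iD], so the normalized dimension (iD = 3) is last
            INDArray x = Nd4j.rand(new int[] {2, 4, 4, 3});
            INDArray mean = Nd4j.zeros(DataType.FLOAT, 3);
            INDArray variance = Nd4j.ones(DataType.FLOAT, 3);
            INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);  // consumed because "apply scale" = 1
            INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);  // consumed because "apply offset" = 1
            INDArray out = Nd4j.createUninitialized(DataType.FLOAT, x.shape());

            DynamicCustomOp op = DynamicCustomOp.builder("batchnorm")
                    .addInputs(x, mean, variance, gamma, beta)
                    .addOutputs(out)
                    .addIntegerArguments(1, 1)          // int args: apply scale, apply offset
                    .addFloatingPointArguments(1e-3)    // T arg 0: epsilon
                    .build();
            Nd4j.exec(op);
        }
    }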
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
index f95842cc2..ed8f4d441 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
@@ -1375,6 +1375,24 @@ public class RandomTests extends BaseNd4jTest {
         log.info("Array: {}", array);
     }
 
+    @Test
+    public void testOrthogonalDistribution2() {
+        val dist = new OrthogonalDistribution(1.0);
+
+        val array = dist.sample(new int[] {9, 6});
+
+        log.info("Array: {}", array);
+    }
+
+    @Test
+    public void testOrthogonalDistribution3() {
+        val dist = new OrthogonalDistribution(1.0);
+
+        val array = dist.sample(new int[] {9, 9});
+
+        log.info("Array: {}", array);
+    }
+
     @Test
     public void reproducabilityTest(){
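The two new tests only log the samples. A natural follow-up assertion, shown here as a sketch in the same style (hypothetical, not part of the patch; assumes the file's existing JUnit imports), is that a square sample is orthogonal up to the gain, i.e. q^T q = gain^2 * I:

    @Test
    public void testOrthogonalDistributionGain() {
        // hypothetical extra check: with gain = 2.0 the sampled matrix satisfies
        // q^T * q ~ gain^2 * I, within float tolerance
        val dist = new OrthogonalDistribution(2.0);

        val q = dist.sample(new int[] {9, 9});
        val gram = q.transpose().mmul(q);

        assertTrue(gram.equalsWithEps(Nd4j.eye(9).mul(4.0), 1e-3));
    }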