[WIP] SVD (#16)

* - new SVD constructor
- OrthogonalDistribution now uses SVD custom op

Signed-off-by: raver119 <raver119@gmail.com>

* shapes fixed

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-10-28 12:31:01 +03:00 committed by GitHub
parent 029a69a835
commit 5a4d2e8b31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 49 deletions

View File

@ -47,6 +47,20 @@ public class Svd extends DynamicCustomOp {
public Svd(){ } public Svd(){ }
public Svd(INDArray input, boolean full_matrices, INDArray s, INDArray u, INDArray v) {
inputArguments.add(input);
fullUV = full_matrices;
computeUv = true;
switchNum = DEFAULT_SWITCHNUM;
outputArguments.add(s);
outputArguments.add(u);
outputArguments.add(v);
addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum);
}
public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){ public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){
this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM); this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM);
} }

View File

@ -21,6 +21,7 @@ import lombok.val;
import org.apache.commons.math3.exception.NumberIsTooLargeException; import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException; import org.apache.commons.math3.exception.OutOfRangeException;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.distribution.BaseDistribution; import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -231,21 +232,20 @@ public class OrthogonalDistribution extends BaseDistribution {
val flatShape = new long[]{numRows, numCols}; val flatShape = new long[]{numRows, numCols};
val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random); val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
long m = flatRng.rows(); val m = flatRng.rows();
long n = flatRng.columns(); val n = flatRng.columns();
val s = Nd4j.create(dtype, m < n ? m : n); val s = Nd4j.create(dtype, m < n ? m : n);
val u = m < n ? Nd4j.create(dtype, m, n) : Nd4j.create(dtype, m, m); val u = Nd4j.create(dtype, m, m);
val v = Nd4j.create(dtype, new long[] {n, n}, 'f'); val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
Nd4j.getBlasWrapper().lapack().gesvd(flatRng, s, u, v); Nd4j.exec(new Svd(flatRng, true, s, u, v));
// FIXME: int cast
if (gains == null) { if (gains == null) {
if (u.rows() == numRows && u.columns() == numCols) { if (u.rows() >= numRows && u.columns() >= numCols) {
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
} else {
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
} else {
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
} }
} else { } else {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -598,6 +598,10 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public native @Cast("bool") boolean isCPU(); public native @Cast("bool") boolean isCPU();
public native int blasMajorVersion();
public native int blasMinorVersion();
public native int blasPatchVersion();
public native @StdVector Pair capabilities(); public native @StdVector Pair capabilities();
} }
@ -12281,6 +12285,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <ops/declarable/headers/datatypes.h> // #include <ops/declarable/headers/datatypes.h>
// #include <ops/declarable/headers/third_party.h> // #include <ops/declarable/headers/third_party.h>
// #include <ops/declarable/headers/tests.h> // #include <ops/declarable/headers/tests.h>
// #include <ops/declarable/headers/kernels.h>
// #include <ops/declarable/headers/BarnesHutTsne.h> // #include <ops/declarable/headers/BarnesHutTsne.h>
// #include <dll.h> // #include <dll.h>
// #include <helpers/shape.h> // #include <helpers/shape.h>
@ -21398,12 +21403,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/** /**
* Local response normalization implementation as TF. * Local response normalization implementation as TF.
* input: 4D array * input: 4D array
* *
* T args: * T args:
* *
* 0: bias * 0: bias
@ -21411,8 +21416,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* 2: beta * 2: beta
* *
* Int arg: depth - optional local radius * Int arg: depth - optional local radius
* *
* output - 4D array * output - 4D array
*/ */
// #if NOT_EXCLUDED(OP_lrn) // #if NOT_EXCLUDED(OP_lrn)
@Namespace("nd4j::ops") public static class lrn extends DeclarableOp { @Namespace("nd4j::ops") public static class lrn extends DeclarableOp {
@ -21434,10 +21439,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/** /**
* Local response normalization - backprop variant. * Local response normalization - backprop variant.
* input: * input:
* 0 - 4D array of data * 0 - 4D array of data
* 1 - epsilon - 4D array of approximation * 1 - epsilon - 4D array of approximation
* *
* T args: * T args:
* *
* 0: bias * 0: bias
@ -21467,21 +21472,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif // #endif
/** /**
* Batch normalization implementation. * Batch normalization implementation.
* Reference: https://arxiv.org/abs/1502.03167v3 * Reference: https://arxiv.org/abs/1502.03167v3
* *
* Expected arguments: * Expected arguments:
* input: input array (any number of dimensions) * input: input array (any number of dimensions)
* mean: * mean:
* variance: * variance:
* gamma: * gamma:
* beta: * beta:
* *
* Int args: * Int args:
* 0: apply scale * 0: apply scale
* 1: apply offset * 1: apply offset
* *
* *
* T args: * T args:
* 0: epsilon * 0: epsilon
*/ */
@ -21502,27 +21507,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
// #if NOT_EXCLUDED(OP_batchnorm_new)
@Namespace("nd4j::ops") public static class batchnorm_new extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public batchnorm_new(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public batchnorm_new(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public batchnorm_new position(long position) {
return (batchnorm_new)super.position(position);
}
public batchnorm_new() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/** /**
* back prop in batch normalization * back prop in batch normalization
* *
* Expected arguments: * Expected arguments:
* input: input array (any number of dimensions) * input: input array (any number of dimensions)
* mean: * mean:
@ -21530,11 +21518,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* gamma: optional * gamma: optional
* beta: optional * beta: optional
* dLdOut: next epsilon * dLdOut: next epsilon
* *
* Int args: * Int args:
* 0: apply scale * 0: apply scale
* 1: apply offset * 1: apply offset
* *
* T args: * T args:
* 0: epsilon * 0: epsilon
* *
@ -21542,8 +21530,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* dL/dInput * dL/dInput
* dL/dMean * dL/dMean
* dL/dVariance * dL/dVariance
* dL/dGamma * dL/dGamma, optional
* dL/dBeta * dL/dBeta, optional
*/ */
// #if NOT_EXCLUDED(OP_batchnorm) // #if NOT_EXCLUDED(OP_batchnorm)
@Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp { @Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp {
@ -21570,7 +21558,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* x: parameters, any shape * x: parameters, any shape
* y: gradients. same shape as x * y: gradients. same shape as x
* lr: optional, learning rate * lr: optional, learning rate
* *
* T args: * T args:
* 0: optional, learning rate * 0: optional, learning rate
*/ */
@ -21589,25 +21577,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public apply_sgd() { super((Pointer)null); allocate(); } public apply_sgd() { super((Pointer)null); allocate(); }
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/** /**
* This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167. * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
* Expected arguments: * Expected arguments:
* x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
* bS - batch size * bS - batch size
* iH - input height * iH - input height
* iW - input width * iW - input width
* iD - input depth (or number of channels) * iD - input depth (or number of channels)
* scale: 1D input array of scale factors, shape [iD] * scale: 1D input array of scale factors, shape [iD]
* offset: 1D input array of offsets (shifts), shape [iD] * offset: 1D input array of offsets (shifts), shape [iD]
* mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
* variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
* *
* T input arguments: * T input arguments:
* 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
* *
* integer input arguments: * integer input arguments:
* 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
* 1: isTraining, may have two values: zero -> inference, unity -> training * 1: isTraining, may have two values: zero -> inference, unity -> training

View File

@ -1375,6 +1375,24 @@ public class RandomTests extends BaseNd4jTest {
log.info("Array: {}", array); log.info("Array: {}", array);
} }
@Test
public void testOrthogonalDistribution2() {
val dist = new OrthogonalDistribution(1.0);
val array = dist.sample(new int[] {9, 6});
log.info("Array: {}", array);
}
@Test
public void testOrthogonalDistribution3() {
val dist = new OrthogonalDistribution(1.0);
val array = dist.sample(new int[] {9, 9});
log.info("Array: {}", array);
}
@Test @Test
public void reproducabilityTest(){ public void reproducabilityTest(){