[WIP] SVD (#16)
* - new SVD constructor - OrthogonalDistribution now uses SVD custom op Signed-off-by: raver119 <raver119@gmail.com> * shapes fixed Signed-off-by: raver119 <raver119@gmail.com>master
parent
029a69a835
commit
5a4d2e8b31
|
@ -47,6 +47,20 @@ public class Svd extends DynamicCustomOp {
|
||||||
|
|
||||||
public Svd(){ }
|
public Svd(){ }
|
||||||
|
|
||||||
|
public Svd(INDArray input, boolean full_matrices, INDArray s, INDArray u, INDArray v) {
|
||||||
|
inputArguments.add(input);
|
||||||
|
fullUV = full_matrices;
|
||||||
|
computeUv = true;
|
||||||
|
switchNum = DEFAULT_SWITCHNUM;
|
||||||
|
|
||||||
|
|
||||||
|
outputArguments.add(s);
|
||||||
|
outputArguments.add(u);
|
||||||
|
outputArguments.add(v);
|
||||||
|
|
||||||
|
addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum);
|
||||||
|
}
|
||||||
|
|
||||||
public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){
|
public Svd(SameDiff sd, SDVariable input, boolean fullUV, boolean computeUv){
|
||||||
this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM);
|
this(sd, input, fullUV, computeUv, DEFAULT_SWITCHNUM);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.val;
|
||||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
import org.apache.commons.math3.exception.OutOfRangeException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Svd;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
||||||
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
|
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -231,21 +232,20 @@ public class OrthogonalDistribution extends BaseDistribution {
|
||||||
val flatShape = new long[]{numRows, numCols};
|
val flatShape = new long[]{numRows, numCols};
|
||||||
val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
|
val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random);
|
||||||
|
|
||||||
long m = flatRng.rows();
|
val m = flatRng.rows();
|
||||||
long n = flatRng.columns();
|
val n = flatRng.columns();
|
||||||
|
|
||||||
val s = Nd4j.create(dtype, m < n ? m : n);
|
val s = Nd4j.create(dtype, m < n ? m : n);
|
||||||
val u = m < n ? Nd4j.create(dtype, m, n) : Nd4j.create(dtype, m, m);
|
val u = Nd4j.create(dtype, m, m);
|
||||||
val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
|
val v = Nd4j.create(dtype, new long[] {n, n}, 'f');
|
||||||
|
|
||||||
Nd4j.getBlasWrapper().lapack().gesvd(flatRng, s, u, v);
|
Nd4j.exec(new Svd(flatRng, true, s, u, v));
|
||||||
|
|
||||||
// FIXME: int cast
|
|
||||||
if (gains == null) {
|
if (gains == null) {
|
||||||
if (u.rows() == numRows && u.columns() == numCols) {
|
if (u.rows() >= numRows && u.columns() >= numCols) {
|
||||||
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
|
||||||
} else {
|
|
||||||
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
||||||
|
} else {
|
||||||
|
return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -598,6 +598,10 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
|
|
||||||
public native @Cast("bool") boolean isCPU();
|
public native @Cast("bool") boolean isCPU();
|
||||||
|
|
||||||
|
public native int blasMajorVersion();
|
||||||
|
public native int blasMinorVersion();
|
||||||
|
public native int blasPatchVersion();
|
||||||
|
|
||||||
public native @StdVector Pair capabilities();
|
public native @StdVector Pair capabilities();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12281,6 +12285,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
// #include <ops/declarable/headers/datatypes.h>
|
// #include <ops/declarable/headers/datatypes.h>
|
||||||
// #include <ops/declarable/headers/third_party.h>
|
// #include <ops/declarable/headers/third_party.h>
|
||||||
// #include <ops/declarable/headers/tests.h>
|
// #include <ops/declarable/headers/tests.h>
|
||||||
|
// #include <ops/declarable/headers/kernels.h>
|
||||||
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
// #include <dll.h>
|
// #include <dll.h>
|
||||||
// #include <helpers/shape.h>
|
// #include <helpers/shape.h>
|
||||||
|
@ -21398,12 +21403,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Local response normalization implementation as TF.
|
* Local response normalization implementation as TF.
|
||||||
* input: 4D array
|
* input: 4D array
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
*
|
*
|
||||||
* 0: bias
|
* 0: bias
|
||||||
|
@ -21411,8 +21416,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* 2: beta
|
* 2: beta
|
||||||
*
|
*
|
||||||
* Int arg: depth - optional local radius
|
* Int arg: depth - optional local radius
|
||||||
*
|
*
|
||||||
* output - 4D array
|
* output - 4D array
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_lrn)
|
// #if NOT_EXCLUDED(OP_lrn)
|
||||||
@Namespace("nd4j::ops") public static class lrn extends DeclarableOp {
|
@Namespace("nd4j::ops") public static class lrn extends DeclarableOp {
|
||||||
|
@ -21434,10 +21439,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Local response normalization - backprop variant.
|
* Local response normalization - backprop variant.
|
||||||
* input:
|
* input:
|
||||||
* 0 - 4D array of data
|
* 0 - 4D array of data
|
||||||
* 1 - epsilon - 4D array of approximation
|
* 1 - epsilon - 4D array of approximation
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
*
|
*
|
||||||
* 0: bias
|
* 0: bias
|
||||||
|
@ -21467,21 +21472,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Batch normalization implementation.
|
* Batch normalization implementation.
|
||||||
* Reference: https://arxiv.org/abs/1502.03167v3
|
* Reference: https://arxiv.org/abs/1502.03167v3
|
||||||
*
|
*
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* input: input array (any number of dimensions)
|
* input: input array (any number of dimensions)
|
||||||
* mean:
|
* mean:
|
||||||
* variance:
|
* variance:
|
||||||
* gamma:
|
* gamma:
|
||||||
* beta:
|
* beta:
|
||||||
*
|
*
|
||||||
* Int args:
|
* Int args:
|
||||||
* 0: apply scale
|
* 0: apply scale
|
||||||
* 1: apply offset
|
* 1: apply offset
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: epsilon
|
* 0: epsilon
|
||||||
*/
|
*/
|
||||||
|
@ -21502,27 +21507,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
// #if NOT_EXCLUDED(OP_batchnorm_new)
|
|
||||||
@Namespace("nd4j::ops") public static class batchnorm_new extends DeclarableCustomOp {
|
|
||||||
static { Loader.load(); }
|
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
|
||||||
public batchnorm_new(Pointer p) { super(p); }
|
|
||||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
|
||||||
public batchnorm_new(long size) { super((Pointer)null); allocateArray(size); }
|
|
||||||
private native void allocateArray(long size);
|
|
||||||
@Override public batchnorm_new position(long position) {
|
|
||||||
return (batchnorm_new)super.position(position);
|
|
||||||
}
|
|
||||||
|
|
||||||
public batchnorm_new() { super((Pointer)null); allocate(); }
|
|
||||||
private native void allocate();
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
|
||||||
}
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* back prop in batch normalization
|
* back prop in batch normalization
|
||||||
*
|
*
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* input: input array (any number of dimensions)
|
* input: input array (any number of dimensions)
|
||||||
* mean:
|
* mean:
|
||||||
|
@ -21530,11 +21518,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* gamma: optional
|
* gamma: optional
|
||||||
* beta: optional
|
* beta: optional
|
||||||
* dLdOut: next epsilon
|
* dLdOut: next epsilon
|
||||||
*
|
*
|
||||||
* Int args:
|
* Int args:
|
||||||
* 0: apply scale
|
* 0: apply scale
|
||||||
* 1: apply offset
|
* 1: apply offset
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: epsilon
|
* 0: epsilon
|
||||||
*
|
*
|
||||||
|
@ -21542,8 +21530,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* dL/dInput
|
* dL/dInput
|
||||||
* dL/dMean
|
* dL/dMean
|
||||||
* dL/dVariance
|
* dL/dVariance
|
||||||
* dL/dGamma
|
* dL/dGamma, optional
|
||||||
* dL/dBeta
|
* dL/dBeta, optional
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_batchnorm)
|
// #if NOT_EXCLUDED(OP_batchnorm)
|
||||||
@Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp {
|
@Namespace("nd4j::ops") public static class batchnorm_bp extends DeclarableCustomOp {
|
||||||
|
@ -21570,7 +21558,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* x: parameters, any shape
|
* x: parameters, any shape
|
||||||
* y: gradients. same shape as x
|
* y: gradients. same shape as x
|
||||||
* lr: optional, learning rate
|
* lr: optional, learning rate
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: optional, learning rate
|
* 0: optional, learning rate
|
||||||
*/
|
*/
|
||||||
|
@ -21589,25 +21577,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public apply_sgd() { super((Pointer)null); allocate(); }
|
public apply_sgd() { super((Pointer)null); allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
|
* This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
|
* x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
|
||||||
* bS - batch size
|
* bS - batch size
|
||||||
* iH - input height
|
* iH - input height
|
||||||
* iW - input width
|
* iW - input width
|
||||||
* iD - input depth (or number of channels)
|
* iD - input depth (or number of channels)
|
||||||
* scale: 1D input array of scale factors, shape [iD]
|
* scale: 1D input array of scale factors, shape [iD]
|
||||||
* offset: 1D input array of offsets (shifts), shape [iD]
|
* offset: 1D input array of offsets (shifts), shape [iD]
|
||||||
* mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
* mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
||||||
* variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
* variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
||||||
*
|
*
|
||||||
* T input arguments:
|
* T input arguments:
|
||||||
* 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
|
* 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
|
||||||
*
|
*
|
||||||
* integer input arguments:
|
* integer input arguments:
|
||||||
* 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
|
* 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
|
||||||
* 1: isTraining, may have two values: zero -> inference, unity -> training
|
* 1: isTraining, may have two values: zero -> inference, unity -> training
|
||||||
|
|
|
@ -1375,6 +1375,24 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
log.info("Array: {}", array);
|
log.info("Array: {}", array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOrthogonalDistribution2() {
|
||||||
|
val dist = new OrthogonalDistribution(1.0);
|
||||||
|
|
||||||
|
val array = dist.sample(new int[] {9, 6});
|
||||||
|
|
||||||
|
log.info("Array: {}", array);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOrthogonalDistribution3() {
|
||||||
|
val dist = new OrthogonalDistribution(1.0);
|
||||||
|
|
||||||
|
val array = dist.sample(new int[] {9, 9});
|
||||||
|
|
||||||
|
log.info("Array: {}", array);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reproducabilityTest(){
|
public void reproducabilityTest(){
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue