[WIP] Fixed signatures. SameDiff tests (#258)
* Fixed signatures. SameDiff tests
* Tests fixed
* Test fixed
* Small fix
* Fixed test

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
parent a66e03355e
commit bc2a7dd7ae
@@ -2836,13 +2836,13 @@ public class Nd4j {
      * @param columns the number of columns in the matrix
      * @return the random ndarray with the specified shape
      */
-    public static INDArray rand(int rows, int columns) {
+    /*public static INDArray rand(int rows, int columns) {
         if (rows < 1 || columns < 1)
             throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray");
 
         INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());
         return rand(ret);
-    }
+    }*/
 
     /**
      * Create a random ndarray with the given shape and output order
@@ -2853,13 +2853,13 @@ public class Nd4j {
      * @param columns the number of columns in the matrix
      * @return the random ndarray with the specified shape
      */
-    public static INDArray rand(char order, int rows, int columns) {
+    /*public static INDArray rand(char order, int rows, int columns) {
         if (rows < 1 || columns < 1)
             throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray");
 
         INDArray ret = createUninitialized(new int[] {rows, columns}, order);//INSTANCE.rand(order, rows, columns);
         return rand(ret);
-    }
+    }*/
 
     /**
      * Create a random ndarray with values from a uniform distribution over (0, 1) with the given shape
@@ -2892,10 +2892,10 @@ public class Nd4j {
      * @param seed the seed to use
      * @return the random ndarray with the specified shape
      */
-    public static INDArray rand(int rows, int columns, long seed) {
+    /*public static INDArray rand(int rows, int columns, long seed) {
         INDArray ret = createUninitialized(new int[] {rows, columns}, Nd4j.order());
         return rand(ret, seed);
-    }
+    }*/
 
     /**
      * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.rng.Random, long...)}
@@ -2999,10 +2999,10 @@ public class Nd4j {
      * @param rng the rng to use
      * @return a drandom matrix of the specified shape and range
      */
-    public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
+    /*public static INDArray rand(int rows, int columns, double min, double max, @NonNull org.nd4j.linalg.api.rng.Random rng) {
         INDArray ret = createUninitialized(rows, columns);
         return rand(ret, min, max, rng);
-    }
+    }*/
 
     /**
      * Fill the given ndarray with random numbers drawn from a normal distribution
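The four rand overloads above that took bare int row/column counts are commented out; callers move to the explicit-shape entry points. A minimal migration sketch in Scala, mirroring how the tests later in this commit were updated (a Scala Array(1, 10000) compiles to a JVM int[]):

    import org.nd4j.linalg.factory.Nd4j

    // Before: Nd4j.rand(1, 10000, 0, 1, Nd4j.getRandom())
    // After: the shape is passed as a single explicit int[] argument.
    val initial = Nd4j.rand(Array(1, 10000), 0, 1, Nd4j.getRandom)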
@@ -3020,7 +3020,7 @@ public class Nd4j {
      * @param shape the shape of the array
      * @return new array with random values
      */
-    public static INDArray randn(@NonNull int... shape) {
+    public static INDArray randn(@NonNull int[] shape) {
         return randn(ArrayUtil.toLongArray(shape));
     }
 
@@ -3031,7 +3031,7 @@ public class Nd4j {
      * @param shape the shape of the ndarray
      * @return new array with random values
      */
-    public static INDArray randn(@NonNull DataType dataType, @NonNull int... shape) {
+    public static INDArray randn(@NonNull DataType dataType, @NonNull int[] shape) {
         return randn(dataType, ArrayUtil.toLongArray(shape));
     }
 
@@ -3135,10 +3135,10 @@ public class Nd4j {
      * @param rows the number of rows in the matrix
      * @param columns the number of columns in the matrix
      */
-    public static INDArray randn(char order, long rows, long columns) {
+    /*public static INDArray randn(char order, long rows, long columns) {
         INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order);
         return randn(ret);
-    }
+    }*/
 
     /**
      * Random normal using the specified seed
@@ -3147,10 +3147,10 @@ public class Nd4j {
      * @param columns the number of columns in the matrix
      * @return new array with random values
      */
-    public static INDArray randn(long rows, long columns, long seed) {
+    /*public static INDArray randn(long rows, long columns, long seed) {
         INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
         return randn(ret, seed);
-    }
+    }*/
 
     /**
      * Random normal using the given rng
@@ -3160,10 +3160,10 @@ public class Nd4j {
      * @param r the random generator to use
      * @return new array with random values
      */
-    public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
+    /*public static INDArray randn(long rows, long columns, @NonNull org.nd4j.linalg.api.rng.Random r) {
         INDArray ret = Nd4j.createUninitialized(new long[]{rows, columns}, order());
         return randn(ret, r);
-    }
+    }*/
 
     /**
      * @deprecated use {@link Nd4j#randn(org.nd4j.linalg.api.rng.Random, long...)}
@@ -3193,6 +3193,14 @@ public class Nd4j {
         return randn(ret, r);
     }
 
+    public static INDArray randn(double mean, double stddev, INDArray target, @NonNull org.nd4j.linalg.api.rng.Random rng) {
+        return getExecutioner().exec(new GaussianDistribution(target, mean, stddev), rng);
+    }
+
+    public static INDArray randn(double mean, double stddev, long[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) {
+        INDArray target = Nd4j.createUninitialized(shape);
+        return getExecutioner().exec(new GaussianDistribution(target, mean, stddev), rng);
+    }
     /**
      * Fill the given ndarray with random numbers drawn from a uniform distribution
      *
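The two randn overloads added here take an explicit mean and standard deviation, filling either a caller-supplied target array or a freshly allocated one of the given shape. A usage sketch in Scala, matching how the new tests in this commit call them (Array(1000L) is a JVM long[]):

    import org.nd4j.linalg.factory.Nd4j

    val rng = Nd4j.getRandom
    rng.setSeed(7)
    // 1000 samples drawn from N(3.0, 1.0) into a new array of shape [1000].
    val samples = Nd4j.randn(3.0, 1.0, Array(1000L), rng)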
@@ -3361,9 +3369,9 @@ public class Nd4j {
      * @param columns columns
      * @return uninitialized 2D array of rows x columns
      */
-    public static INDArray createUninitialized(long rows, long columns) {
+    /*public static INDArray createUninitialized(long rows, long columns) {
         return createUninitialized(new long[] {rows, columns});
-    }
+    }*/
 
     /**
      * Creates a row vector with the data
@@ -3740,7 +3748,7 @@ public class Nd4j {
      * @param shape the shape of the array
      * @return the created ndarray
      */
-    public static INDArray create(float[] data, int... shape) {
+    public static INDArray create(float[] data, int[] shape) {
         if (shape.length == 0 && data.length == 1) {
             return scalar(data[0]);
         }
@@ -3782,7 +3790,7 @@ public class Nd4j {
      * @param shape the shape of the array
      * @return the created ndarray
      */
-    public static INDArray create(double[] data, int... shape) {
+    public static INDArray create(double[] data, int[] shape) {
         commonCheckCreate(data.length, LongUtils.toLongs(shape));
         val lshape = ArrayUtil.toLongArray(shape);
         return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
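With the int... varargs narrowed to plain int[] parameters on create (and on randn above), call sites that expanded a Scala array with ": _*" now pass the array directly, as the nd4s test changes below show:

    import org.nd4j.linalg.factory.Nd4j

    // Before: Nd4j.create(Array[Float](1, 2), Array(2, 1): _*)
    // After: the shape array is an ordinary argument.
    val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1))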
@@ -23,6 +23,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import lombok.extern.slf4j.Slf4j;
+import lombok.val;
 import org.junit.Test;
 import org.nd4j.autodiff.listeners.impl.ScoreListener;
 import org.nd4j.autodiff.listeners.records.History;
@@ -30,8 +31,11 @@ import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.IrisDataSetIterator;
+import org.nd4j.linalg.dataset.MultiDataSet;
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
 import org.nd4j.linalg.factory.Nd4j;
@@ -290,6 +294,54 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
 
     }
 
+    @Test
+    public void simpleClassification() {
+        double learning_rate = 0.001;
+        int seed = 7;
+        org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
+        rng.setSeed(seed);
+        INDArray x1_label1 = Nd4j.randn(3.0, 1.0, new long[]{1000}, rng);
+        INDArray x2_label1 = Nd4j.randn(2.0, 1.0, new long[]{1000}, rng);
+        INDArray x1_label2 = Nd4j.randn(7.0, 1.0, new long[]{1000}, rng);
+        INDArray x2_label2 = Nd4j.randn(6.0, 1.0, new long[]{1000}, rng);
+
+        INDArray x1s = Nd4j.concat(0, x1_label1, x1_label2);
+        INDArray x2s = Nd4j.concat(0, x2_label1, x2_label2);
+
+        SameDiff sd = SameDiff.create();
+        INDArray ys = Nd4j.scalar(0.0).mul(x1_label1.length()).add(Nd4j.scalar(1.0).mul(x1_label2.length()));
+
+        SDVariable X1 = sd.placeHolder("x1", DataType.DOUBLE, 2000);
+        SDVariable X2 = sd.placeHolder("x2", DataType.DOUBLE, 2000);
+        SDVariable y = sd.placeHolder("y", DataType.DOUBLE);
+        SDVariable w = sd.var("w", DataType.DOUBLE, 3);
+
+        // TF code:
+        //cost = tf.reduce_mean(-tf.log(y_model * Y + (1 — y_model) * (1 — Y)))
+        SDVariable y_model =
+                sd.nn.sigmoid(w.get(SDIndex.point(2)).mul(X2).add(w.get(SDIndex.point(1)).mul(X1)).add(w.get(SDIndex.point(0))));
+        SDVariable cost_fun =
+                (sd.math.neg(sd.math.log(y_model.mul(y).add((sd.math.log(sd.constant(1.0).minus(y_model)).mul(sd.constant(1.0).minus(y)))))));
+        SDVariable loss = sd.mean("loss", cost_fun);
+
+        val updater = new Sgd(learning_rate);
+
+        sd.setLossVariables("loss");
+        sd.createGradFunction();
+        val conf = new TrainingConfig.Builder()
+                .updater(updater)
+                .minimize("loss")
+                .dataSetFeatureMapping("x1", "x2", "y")
+                .markLabelsUnused()
+                .build();
+
+        MultiDataSet mds = new MultiDataSet(new INDArray[]{x1s, x2s, ys},null);
+
+        sd.setTrainingConfig(conf);
+        History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
+    }
+
+
     @Override
     public char ordering() {
         return 'c';
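For reference, the cost in the quoted TF comment is the two-class cross-entropy

    \mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log\bigl(\hat{y}_i\, y_i + (1-\hat{y}_i)(1-y_i)\bigr),
    \qquad \hat{y}_i = \sigma(w_2 x_{2,i} + w_1 x_{1,i} + w_0)

with \hat{y} the sigmoid output y_model and y the 0/1 cluster label.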
@@ -458,7 +458,7 @@ public class CompressionTests extends BaseNd4jTest {
     @Test
     public void testBitmapEncoding4() {
         Nd4j.getRandom().setSeed(119);
-        INDArray initial = Nd4j.rand(1, 10000, 0, 1, Nd4j.getRandom());
+        INDArray initial = Nd4j.rand(new int[]{1, 10000}, 0, 1, Nd4j.getRandom());
         INDArray exp_1 = initial.dup();
 
         INDArray enc = Nd4j.getExecutioner().bitmapEncode(initial, 1e-1);
@@ -471,7 +471,7 @@ public class CompressionTests extends BaseNd4jTest {
     @Test
     public void testBitmapEncoding5() {
         Nd4j.getRandom().setSeed(119);
-        INDArray initial = Nd4j.rand(1, 10000, -1, -0.5, Nd4j.getRandom());
+        INDArray initial = Nd4j.rand(new int[]{1, 10000}, -1, -0.5, Nd4j.getRandom());
         INDArray exp_0 = initial.dup().addi(1e-1);
         INDArray exp_1 = initial.dup();
 
@@ -486,7 +486,7 @@ public class CompressionTests extends BaseNd4jTest {
     @Test
     public void testBitmapEncoding6() {
         Nd4j.getRandom().setSeed(119);
-        INDArray initial = Nd4j.rand(1, 100000, -1, 1, Nd4j.getRandom());
+        INDArray initial = Nd4j.rand(new int[]{1, 100000}, -1, 1, Nd4j.getRandom());
         INDArray exp_1 = initial.dup();
 
         INDArray enc = Nd4j.getExecutioner().bitmapEncode(initial, 1e-3);
@@ -186,12 +186,12 @@ public class NormalizerStandardizeLabelsTest extends BaseNd4jTest {
         int i = 0;
         // Randomly generate scaling constants and add offsets
         // to get aA and bB
-        INDArray aA = a == 1 ? Nd4j.ones(1, nFeatures) : Nd4j.rand(1, nFeatures, randSeed).mul(a); //a = 1, don't scale
-        INDArray bB = Nd4j.rand(1, nFeatures, randSeed).mul(b); //b = 0 this zeros out
+        INDArray aA = a == 1 ? Nd4j.ones(1, nFeatures) : Nd4j.rand(new int[]{1, nFeatures}, randSeed).mul(a); //a = 1, don't scale
+        INDArray bB = Nd4j.rand(new int[]{1, nFeatures}, randSeed).mul(b); //b = 0 this zeros out
         // transform ndarray as X = aA + bB * X
         INDArray randomFeatures = Nd4j.zeros(nSamples, nFeatures);
         while (i < nFeatures) {
-            INDArray randomSlice = Nd4j.randn(nSamples, 1, randSeed);
+            INDArray randomSlice = Nd4j.randn(new int[]{nSamples, 1}, randSeed);
             randomSlice.muli(aA.getScalar(0, i));
             randomSlice.addi(bB.getScalar(0, i));
             randomFeatures.putColumn(i, randomSlice);
@@ -303,13 +303,13 @@ public class NormalizerStandardizeTest extends BaseNd4jTest {
         int i = 0;
         // Randomly generate scaling constants and add offsets
         // to get aA and bB
-        INDArray aA = a == 1 ? Nd4j.ones(1, nFeatures) : Nd4j.rand(1, nFeatures, randSeed).mul(a); //a = 1, don't scale
-        INDArray bB = Nd4j.rand(1, nFeatures, randSeed).mul(b); //b = 0 this zeros out
+        INDArray aA = a == 1 ? Nd4j.ones(1, nFeatures) : Nd4j.rand(new int[]{1, nFeatures}, randSeed).mul(a); //a = 1, don't scale
+        INDArray bB = Nd4j.rand(new int[]{1, nFeatures}, randSeed).mul(b); //b = 0 this zeros out
         // transform ndarray as X = aA + bB * X
         INDArray randomFeatures = Nd4j.zeros(nSamples, nFeatures);
         INDArray randomFeaturesTransform = Nd4j.zeros(nSamples, nFeatures);
         while (i < nFeatures) {
-            INDArray randomSlice = Nd4j.randn(nSamples, 1, randSeed);
+            INDArray randomSlice = Nd4j.randn(new int[]{nSamples, 1}, randSeed);
             randomFeaturesTransform.putColumn(i, randomSlice);
             randomSlice.muli(aA.getScalar(0, i));
             randomSlice.addi(bB.getScalar(0, i));
@@ -1048,7 +1048,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
 
     @Test
     public void testNorm2_2() {
-        INDArray array = Nd4j.rand(127, 164, 1, 100, Nd4j.getRandom());
+        INDArray array = Nd4j.rand(new int[]{127, 164}, 1, 100, Nd4j.getRandom());
 
         double norm2 = array.norm2Number().doubleValue();
     }
@@ -906,8 +906,8 @@ public class RandomTests extends BaseNd4jTest {
     public void testSignatures1() {
 
         for (int x = 0; x < 100; x++) {
-            INDArray z1 = Nd4j.randn(128, 1, 5325235);
-            INDArray z2 = Nd4j.randn(128, 1, 5325235);
+            INDArray z1 = Nd4j.randn(new int[]{128, 1}, 5325235);
+            INDArray z2 = Nd4j.randn(new int[]{128, 1}, 5325235);
 
             assertEquals(z1, z2);
         }
@@ -50,7 +50,7 @@ object Implicits {
       Nd4j.create(underlying, shape, ord.value)
 
     def asNDArray(shape: Int*): INDArray =
-      Nd4j.create(underlying.toArray, shape.toArray: _*)
+      Nd4j.create(underlying.toArray, shape.toArray)
 
     def toNDArray: INDArray = Nd4j.create(underlying)
   }
@@ -66,7 +66,7 @@ object Implicits {
       Nd4j.create(underlying, shape, offset, ord.value)
 
     def asNDArray(shape: Int*): INDArray =
-      Nd4j.create(underlying.toArray, shape.toArray: _*)
+      Nd4j.create(underlying.toArray, shape.toArray)
 
     def toNDArray: INDArray = Nd4j.create(underlying)
   }
@@ -110,11 +110,11 @@ trait NDArrayEvidence[NDArray <: INDArray, Value] {
 
   def put(a: NDArray, i: Array[Int], element: INDArray): NDArray
 
-  def get(a: NDArray, i: Int): Value
+  def get(a: NDArray, i: Long): Value
 
-  def get(a: NDArray, i: Int, j: Int): Value
+  //def get(a: NDArray, i: Long, j: Long): Value
 
-  def get(a: NDArray, i: Int*): Value
+  def get(a: NDArray, i: Long*): Value
 
   def get(a: NDArray, i: INDArrayIndex*): NDArray
 
@@ -254,11 +254,15 @@ case object DoubleNDArrayEvidence extends RealNDArrayEvidence[Double] {
   override def norm1(ndarray: INDArray): Double =
     ndarray.norm1Number().doubleValue()
 
-  override def get(a: INDArray, i: Int): Double = a.getDouble(i.toLong)
+  override def get(a: INDArray, i: Long): Double = a.getDouble(i)
 
-  override def get(a: INDArray, i: Int, j: Int): Double = a.getDouble(i.toLong, j.toLong)
+  //override def get(a: INDArray, i: Int): Double = a.getDouble(i.toLong)
 
-  override def get(a: INDArray, i: Int*): Double = a.getDouble(i: _*)
+  //override def get(a: INDArray, i: Int, j: Int): Double = a.getDouble(i.toLong, j.toLong)
+
+  //override def get(a: INDArray, i: Int*): Double = a.getDouble(i: _*)
+
+  override def get(a: INDArray, i: Long*): Double = a.getDouble(i: _*)
 
   override def create(arr: Array[Double]): INDArray = arr.toNDArray
 
@@ -315,11 +319,13 @@ case object FloatNDArrayEvidence extends RealNDArrayEvidence[Float] {
   override def norm1(ndarray: INDArray): Float =
     ndarray.norm1Number().floatValue()
 
-  override def get(a: INDArray, i: Int): Float = a.getFloat(i)
+  override def get(a: INDArray, i: Long): Float = a.getFloat(i)
 
-  override def get(a: INDArray, i: Int, j: Int): Float = a.getFloat(i, j)
+  //override def get(a: INDArray, i: Long, j: Long): Float = a.getFloat(i, j)
 
-  override def get(a: INDArray, i: Int*): Float = a.getFloat(i.toArray)
+  //override def get(a: INDArray, i: Int*): Float = a.getFloat(i: _*)
+
+  override def get(a: INDArray, i: Long*): Float = a.getFloat(i: _*)
 
   override def create(arr: Array[Float]): INDArray = arr.toNDArray
 
@@ -440,11 +446,14 @@ case object IntNDArrayEvidence extends IntegerNDArrayEvidence[Int] {
 
   def variance(ndarray: INDArray): Int = ndarray.varNumber().intValue()
 
+  def get(a: INDArray, i: Long): Int = a.getInt(i.toInt)
+
   def get(a: INDArray, i: Int): Int = a.getInt(i)
 
   def get(a: INDArray, i: Int, j: Int): Int = a.getInt(i, j)
 
-  def get(a: INDArray, i: Int*): Int = a.getInt(i: _*)
+  def get(a: INDArray, i: Long*): Int =
+    a.getInt(i.map(_.toInt): _*)
 
   def create(arr: Array[Int]): INDArray = arr.toNDArray
 
@@ -480,11 +489,13 @@ case object LongNDArrayEvidence extends IntegerNDArrayEvidence[Long] {
 
   def variance(ndarray: INDArray): Long = ndarray.varNumber().longValue()
 
+  def get(a: INDArray, i: Long): Long = a.getLong(i)
+
   def get(a: INDArray, i: Int): Long = a.getLong(i)
 
   def get(a: INDArray, i: Int, j: Int): Long = a.getLong(i, j)
 
-  def get(a: INDArray, i: Int*): Long = a.getLong(i.map(_.toLong): _*)
+  def get(a: INDArray, i: Long*): Long = a.getLong(i: _*)
 
   def create(arr: Array[Long]): INDArray = arr.toNDArray
 
@@ -523,11 +534,13 @@ case object ByteNDArrayEvidence extends IntegerNDArrayEvidence[Byte] {
 
   def variance(ndarray: INDArray): Byte = ndarray.varNumber().byteValue()
 
+  def get(a: INDArray, i: Long): Byte = a.getInt(i.toInt).toByte
+
   def get(a: INDArray, i: Int): Byte = a.getInt(i).toByte
 
   def get(a: INDArray, i: Int, j: Int): Byte = a.getInt(i, j).toByte
 
-  def get(a: INDArray, i: Int*): Byte = a.getInt(i.map(_.toInt): _*).toByte
+  def get(a: INDArray, i: Long*): Byte = a.getInt(i.map(_.toInt): _*).toByte
 
   def create(arr: Array[Byte]): INDArray = arr.toNDArray
 
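Across the evidence objects, element access standardizes on Long indices (ND4J's native index type), while the old Int-based getters survive commented out or as non-overriding helpers. Callers can pin the new overload with a type ascription, as the updated extraction tests below do; a small sketch:

    import org.nd4s.Implicits._

    val step = (0 to 9).toNDArray
    // An Int literal ascribed to Long selects the Long accessor explicitly.
    val v = step.getFloat(0: Long)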
@@ -16,7 +16,6 @@
 package org.nd4s
 
 import org.nd4j.linalg.api.ndarray.INDArray
-import org.nd4j.linalg.indexing.INDArrayIndex
 
 /**
   * Scala DSL for arrays
@@ -121,7 +120,7 @@ trait OperatableNDArray[A <: INDArray] {
     ev.get(underlying, i, j)
 
   def get[B](indices: Int*)(implicit ev: NDArrayEvidence[A, B]): B =
-    ev.get(underlying, indices: _*)
+    ev.get(underlying, indices.map(_.toLong): _*)
 
   def apply[B](i: Int)(implicit ev: NDArrayEvidence[A, B]): B = get(i)
 
@@ -129,10 +128,10 @@ trait OperatableNDArray[A <: INDArray] {
     get(i, j)
 
   def apply[B](indices: Int*)(implicit ev: NDArrayEvidence[A, B]): B =
-    get(indices: _*)
+    ev.get(underlying, indices.map(_.toLong): _*)
 
   def get[B](indices: Array[Int])(implicit ev: NDArrayEvidence[A, B]): B =
-    ev.get(underlying, indices: _*)
+    ev.get(underlying, indices.map(_.toLong): _*)
 
   def unary_-(): INDArray = underlying.neg()
 
@@ -16,8 +16,7 @@
 package org.nd4s.samediff
 
 import org.nd4j.linalg.api.ndarray.INDArray
-import org.nd4j.autodiff.samediff.SDVariable
-import org.nd4j.autodiff.samediff.SameDiff
+import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
 import org.nd4j.linalg.api.buffer.DataType
 import org.nd4j.linalg.factory.Nd4j
 
@@ -49,7 +48,7 @@ class SameDiffWrapper {
     sd.`var`(name, dataType, shape: _*)
 
   def placeHolder(name: String, dataType: DataType, shape: Long*): SDVariable =
-    sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+    sd.placeHolder(name, dataType, shape: _*)
 }
 
 class SDVariableWrapper {
@@ -62,6 +61,10 @@ class SDVariableWrapper {
     thisVariable = variable
   }
 
+  def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
+
+  def add(other: Double): Unit = thisVariable.add(other)
+
   def *(other: SDVariable): SDVariable =
     thisVariable.mul(other)
 
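placeHolder previously ignored its arguments and always created a hard-coded FLOAT placeholder named "ph1"; it now forwards name, data type, and shape. The new apply(index: Long) on SDVariableWrapper delegates to SDIndex.point, so a wrapped variable can be indexed with ordinary parentheses, as the new MathTest at the end of this commit exercises. A usage sketch (assumes org.nd4s.samediff.SDVariableWrapper is in scope):

    import org.nd4j.autodiff.samediff.{ SDIndex, SameDiff }
    import org.nd4j.linalg.factory.Nd4j

    implicit val sd = SameDiff.create
    val x = sd.`var`(Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L))
    val y = new SDVariableWrapper(x)
    // y(0) is equivalent to x.get(SDIndex.point(0))
    assert(x.get(SDIndex.point(0)).getArr == y(0).getArr)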
@@ -29,7 +29,7 @@ class DSLSpec extends FlatSpec with Matchers {
 
   // This test just verifies that an INDArray gets wrapped with an implicit conversion
 
-  val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1): _*)
+  val nd = Nd4j.create(Array[Float](1, 2), Array(2, 1))
   val nd1 = nd + 10L // + creates new array, += modifies in place
 
   nd.get(0) should equal(1)
@@ -220,23 +220,23 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
     val list = (0 to 9).toNDArray
     val step = list(1 -> 7 by 2).reshape(-1)
     assert(step.length() == 3)
-    assert(step.getFloat(0) == 1)
+    assert(step.getFloat(0: Long) == 1)
     assert(step(0) == 1)
     assert(step(0, 0) == 1)
-    assert(step.getFloat(1) == 3)
-    assert(step.getFloat(2) == 5)
+    assert(step.getFloat(1: Long) == 3)
+    assert(step.getFloat(2: Long) == 5)
 
     val filtered = list(-2 -> 10).reshape(-1)
     assert(filtered.length() == 2)
-    assert(filtered.getFloat(0) == 8)
-    assert(filtered.getFloat(1) == 9)
+    assert(filtered.getFloat(0: Long) == 8)
+    assert(filtered.getFloat(1: Long) == 9)
 
     val nStep = list(-3 -> 3 by -1).reshape(-1)
     assert(nStep.length() == 4)
-    assert(nStep.getFloat(0) == 7)
-    assert(nStep.getFloat(1) == 6)
-    assert(nStep.getFloat(2) == 5)
-    assert(nStep.getFloat(3) == 4)
+    assert(nStep.getFloat(0: Long) == 7)
+    assert(nStep.getFloat(1: Long) == 6)
+    assert(nStep.getFloat(2: Long) == 5)
+    assert(nStep.getFloat(3: Long) == 4)
   }
 
   it should "be able to update value with specified indices" in {
@@ -26,27 +26,27 @@ import org.scalatest.{ FlatSpec, Matchers }
 class OperatableNDArrayTest extends FlatSpec with Matchers {
   "RichNDArray" should "use the apply method to access values" in {
     // -- 2D array
-    val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array[Int](1, 4): _*)
+    val nd2 = Nd4j.create(Array[Double](1, 2, 3, 4), Array[Int](1, 4))
 
     nd2.get(0) should be(1)
     nd2.get(0, 3) should be(4)
 
     // -- 3D array
-    val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array[Int](2, 2, 2): _*)
+    val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array[Int](2, 2, 2))
     nd3.get(0, 0, 0) should be(1)
     nd3.get(1, 1, 1) should be(8)
 
   }
 
   it should "use transpose abbreviation" in {
-    val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1): _*)
+    val nd1 = Nd4j.create(Array[Double](1, 2, 3), Array(3, 1))
     nd1.shape should equal(Array(3, 1))
     val nd1t = nd1.T
     nd1t.shape should equal(Array(1, 3))
   }
 
   it should "add correctly" in {
-    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
+    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
     val b = a + 100
     a.get(0, 0, 0) should be(1)
     b.get(0, 0, 0) should be(101)
@@ -55,7 +55,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
   }
 
   it should "subtract correctly" in {
-    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
+    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
     val b = a - 100
     a.get(0, 0, 0) should be(1)
     b.get(0, 0, 0) should be(-99)
@@ -69,7 +69,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
   }
 
   it should "divide correctly" in {
-    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
+    val a = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
     val b = a / a
     a.get(1, 1, 1) should be(8)
     b.get(1, 1, 1) should be(1)
@@ -78,7 +78,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
   }
 
   it should "element-by-element multiply correctly" in {
-    val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1): _*)
+    val a = Nd4j.create(Array[Double](1, 2, 3, 4), Array(4, 1))
     val b = a * a
     a.get(3) should be(4) // [1.0, 2.0, 3.0, 4.0
     b.get(3) should be(16) // [1.0 ,4.0 ,9.0 ,16.0]
@@ -87,7 +87,7 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
   }
 
   it should "use the update method to mutate values" in {
-    val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2): _*)
+    val nd3 = Nd4j.create(Array[Double](1, 2, 3, 4, 5, 6, 7, 8), Array(2, 2, 2))
     nd3(0) = 11
     nd3.get(0) should be(11)
 
@@ -15,10 +15,13 @@
  ******************************************************************************/
 package org.nd4s.samediff
 
-import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff, TrainingConfig }
 import org.nd4j.linalg.api.buffer.DataType
 import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.dataset.MultiDataSet
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
 import org.nd4j.linalg.factory.Nd4j
+import org.nd4j.linalg.learning.config.Sgd
 import org.nd4s.Implicits._
 import org.nd4s.samediff.implicits.Implicits._
 import org.scalatest.{ FlatSpec, Matchers }
@@ -114,4 +117,62 @@ class ConstructionTest extends FlatSpec with Matchers {
     var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
     evaluated3.toFloatVector.head shouldBe 16.0
   }
+
+  "classification example" should "work" in {
+    val learning_rate = 0.1
+    val seed = 7
+
+    val target = Nd4j.createUninitialized(1000)
+    val rng = Nd4j.getRandom
+    rng.setSeed(seed)
+    val x1_label1 = Nd4j.randn(3.0, 1.0, target, rng)
+    val target1 = Nd4j.createUninitialized(1000)
+    val x2_label1 = Nd4j.randn(2.0, 1.0, target1, rng)
+    val target2 = Nd4j.createUninitialized(1000)
+    val x1_label2 = Nd4j.randn(7.0, 1.0, target2, rng)
+    val target3 = Nd4j.createUninitialized(1000)
+    val x2_label2 = Nd4j.randn(6.0, 1.0, target3, rng)
+
+    // np.append, was not able to guess proper method
+    val x1s = Nd4j.concat(0, x1_label1, x1_label2)
+    val x2s = Nd4j.concat(0, x2_label1, x2_label2)
+
+    // Must have implicit sd here for some ops
+    implicit val sd = SameDiff.create
+    val ys = (Nd4j.scalar(0.0) * x1_label1.length()) + (Nd4j.scalar(1.0) * x1_label2.length())
+
+    // Empty shape can't be passed vs tf behaviour
+    val X1 = sd.placeHolder("x1", DataType.DOUBLE, 2000)
+    val X2 = sd.placeHolder("x2", DataType.DOUBLE, 2000)
+    val y = sd.placeHolder("y", DataType.DOUBLE)
+    val w = sd.bind("w", DataType.DOUBLE, Array[Int](3))
+    //Sample: -tf.log(y_model * Y + (1 — y_model) * (1 — Y))
+    val y_model: SDVariable =
+      sd.nn.sigmoid(w(2) * X2 + w(1) * X1 + w(0))
+    val cost_fun: SDVariable = (sd.math.neg(
+      sd.math.log(y_model * y + (sd.math.log(sd.constant(1.0) - y_model) * (sd.constant(1.0) - y)))
+    ))
+    val loss = sd.mean("loss", cost_fun)
+
+    val updater = new Sgd(learning_rate)
+
+    sd.setLossVariables("loss")
+    sd.createGradFunction
+    val conf = new TrainingConfig.Builder()
+      .updater(updater)
+      .minimize("loss")
+      .dataSetFeatureMapping("x1", "x2", "y")
+      .markLabelsUnused()
+      .build()
+
+    val mds = new MultiDataSet(Array[INDArray](x1s, x2s, ys), new Array[INDArray](0))
+
+    sd.setTrainingConfig(conf)
+    sd.fit(new SingletonMultiDataSetIterator(mds), 1)
+
+    w.eval.toDoubleVector.head shouldBe (0.0629 +- 0.0001)
+    w.eval.toDoubleVector.tail.head shouldBe (0.3128 +- 0.0001)
+    w.eval.toDoubleVector.tail.tail.head shouldBe (0.2503 +- 0.0001)
+    //Console.println(w.eval)
+  }
 }
@@ -15,7 +15,7 @@
 ******************************************************************************/
 package org.nd4s.samediff
 
-import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
 import org.nd4j.linalg.api.buffer.DataType
 import org.nd4j.linalg.api.ndarray.INDArray
 import org.nd4j.linalg.factory.Nd4j
@@ -200,4 +200,14 @@ class MathTest extends FlatSpec with Matchers {
     val w3 = w1 >> two
     w3.eval.toIntVector.head shouldBe 4
   }
+
+  "SDVariable " should "be indexable" in {
+    implicit val sd = SameDiff.create
+
+    val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
+    val x = sd.`var`(arr)
+    val y = new SDVariableWrapper(x)
+
+    x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
+  }
 }