Merge remote-tracking branch 'konduit/master'
commit 2844f8b69a
@@ -124,7 +124,6 @@ public class TestUtils {
     public static INDArray randomOneHot(long examples, long nOut, Random rng){
         INDArray arr = Nd4j.create(examples, nOut);
         for( int i=0; i<examples; i++ ){
-            // FIXME: int cast
             arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
         }
         return arr;
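Throughout this merge, the removed "// FIXME: int cast" markers are resolved in one of two ways: where the narrowing cast is inherently bounded, as above (Random.nextInt only accepts an int bound), the marker is simply dropped and the cast kept; where a long length could genuinely overflow an int, the cast is guarded with an ND4JArraySizeException (see the BackPropMLPTest and TestOptimizers hunks later in this diff). A minimal sketch of that guard pattern, using the same names as those later hunks:

    long len = arr.length();
    if (len > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();   // refuse to narrow silently
    float[] f = new float[(int) len];         // the cast is now known to be safe
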
@@ -0,0 +1,107 @@
+package org.deeplearning4j;
+
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.lang.reflect.Field;
+
+import static junit.framework.TestCase.*;
+
+public class TestBatchNormBp {
+
+    @Test
+    public void test(){
+        Nd4j.getRandom().setSeed(12345);
+        // INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
+        INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
+        INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
+        INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
+        INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
+        // INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
+        // INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
+        INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
+        INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
+        double e = 1e-5;
+
+        INDArray dLdIn = in.ulike();
+        INDArray dLdm = mean.ulike();
+        INDArray dLdv = var.ulike();
+        INDArray dLdg = gamma.ulike();
+        INDArray dLdb = beta.ulike();
+
+        DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
+                .addInputs(in, mean, var, eps, gamma, beta)
+                .addIntegerArguments(
+                        1, //Apply scale
+                        1, //Apply beta
+                        1) //Axis (NCHW)
+                .addFloatingPointArguments(e)
+                .addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
+                .build();
+
+        Nd4j.exec(op);
+        System.out.println(dLdIn);
+    }
+
+    @Test
+    public void compareImpls() throws Exception {
+
+        Nd4j.getRandom().setSeed(12345);
+        INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
+        INDArray mean = in.mean(0, 2, 3).reshape(1,3);
+        INDArray var = in.var(0, 2, 3).reshape(1,3);
+        INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
+        INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
+        INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
+        double e = 1e-3;
+
+        INDArray dLdIn = in.ulike();
+        INDArray dLdm = mean.ulike();
+        INDArray dLdv = var.ulike();
+        INDArray dLdg = gamma.ulike();
+        INDArray dLdb = beta.ulike();
+
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .inferenceWorkspaceMode(WorkspaceMode.NONE)
+                .trainingWorkspaceMode(WorkspaceMode.NONE)
+                .list()
+                .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
+                .build();
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
+        assertNotNull(bn.getHelper());
+        Field f = bn.getClass().getDeclaredField("helper");
+        f.setAccessible(true);
+        f.set(bn, null);
+        assertNull(bn.getHelper());
+
+
+        MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
+
+        net.output(in, true);
+        bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
+        Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
+
+        h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
+        Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
+
+        INDArray dldin_dl4j = p.getSecond();
+
+        System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
+    }
+
+}
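The new TestBatchNormBp test above drives the batchnorm_bp backprop op directly: DynamicCustomOp.builder("batchnorm_bp") takes the activations, per-channel mean and variance, the incoming gradient and gamma/beta as inputs, three integer arguments (apply scale, apply beta, channel axis), epsilon as a floating-point argument, and writes the five gradients into preallocated output arrays before Nd4j.exec(op) runs it. compareImpls currently only prints whether the MKL-DNN helper agrees with the built-in layer; a sketch of how that comparison could be asserted instead, reusing the equalsWithEps tolerance check that the ValidateMKLDNN hunk later in this merge uses:

    // Sketch only, assuming the same variables as compareImpls above:
    INDArray dldinDl4j   = p.getSecond();      // built-in BatchNormalization backprop
    INDArray dldinMkldnn = pmkl.getSecond();   // MKLDNNBatchNormHelper backprop
    assertTrue(dldinDl4j.equalsWithEps(dldinMkldnn, 1e-5));
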
@@ -132,7 +132,6 @@ public class TestUtils {
     public static INDArray randomOneHot(long examples, long nOut, Random rng){
         INDArray arr = Nd4j.create(examples, nOut);
         for( int i=0; i<examples; i++ ){
-            // FIXME: int cast
             arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
         }
         return arr;
@@ -187,7 +187,6 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
             }
         }

-
     @Test
     public void testVariableTimeSeries2() throws Exception {
         AsyncDataSetIterator adsi =
@@ -36,6 +36,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.api.TrainingListener;
+import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -242,7 +243,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
         model.init();

-        model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
+        //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));

+        CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
+        model.setListeners(listener);
+
         model.fit(cifar);

@@ -254,6 +258,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
             eval.eval(testDS.getLabels(), output);
         }
         System.out.println(eval.stats(true));
+        listener.exportScores(System.out);
     }


@@ -464,13 +464,11 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 ret[1] = Nd4j.zeros(labelsShape);
                 if (labelsShape.length == 2) {
                     for (int i = 0; i < labelsShape[0]; i++) {
-                        // FIXME: int cast
                         ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), 1.0);
                     }
                 } else if (labelsShape.length == 3) {
                     for (int i = 0; i < labelsShape[0]; i++) {
                         for (int j = 0; j < labelsShape[2]; j++) {
-                            // FIXME: int cast
                             ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, 1.0);
                         }
                     }
@@ -484,13 +482,11 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 ret[1] = Nd4j.ones(labelsShape);
                 if (labelsShape.length == 2) {
                     for (int i = 0; i < labelsShape[0]; i++) {
-                        // FIXME: int cast
                         ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), -1.0);
                     }
                 } else if (labelsShape.length == 3) {
                     for (int i = 0; i < labelsShape[0]; i++) {
                         for (int j = 0; j < labelsShape[2]; j++) {
-                            // FIXME: int cast
                             ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, -1.0);
                         }
                     }
@@ -176,8 +176,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
        manual_weights.put("output_b", c);

        // First things first, let's calculate the score.
-       // FIXME: int cast
-       int batchsz = (int) input.shape()[0];
+       long batchsz = input.shape()[0];
        INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1));
        INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies it's input!!
        INDArray q = a.mmul(V).add(c.repmat(batchsz, 1));
@@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

    @Test
    public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
-
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

        MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                .seed(12345)
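A recurring change in FrozenLayerWithBackpropTest (here and in the hunks that follow): instead of passing a seed to every Nd4j.rand(rows, cols, seed) call, each test now seeds the default RNG once and draws unseeded arrays. The replacement pattern, exactly as it appears in these hunks:

    Nd4j.getRandom().setSeed(12345);
    DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
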
@@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

    @Test
    public void testComputationGraphFrozenLayerParamsAfterBackprop() {
+        Nd4j.getRandom().setSeed(12345);

-        DataSet randomData = new DataSet(Nd4j.rand(100, 4,12345), Nd4j.rand(100, 1, 12345));
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
        String frozenBranchName = "B1-";
        String unfrozenBranchName = "B2-";

@@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
     */
    @Test
    public void testFrozenLayerVsSgd() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));

        MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
-                .layer(0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build()
-                )
-                .layer(1,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(3)
-                                .nOut(4)
-                                .build()
-                ).layer(2,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(4)
-                                .nOut(2)
-                                .build()
-
-                ).layer(3,
-                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .activation(Activation.TANH)
-                                .nIn(2)
-                                .nOut(1)
-                                .build()
-                )
+                .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
+                .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
+                .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
+                .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
                .build();

        MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
@@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(2))
                .list()
-                .layer(0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build()
-                )
-                .layer(1,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(3)
-                                        .nOut(4)
-                                        .build()
-                        )
-                )
-                .layer(2,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(4)
-                                        .nOut(2)
-                                        .build()
-                        )
-                ).layer(3,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                        .activation(Activation.TANH)
-                                        .nIn(2)
-                                        .nOut(1)
-                                        .build()
-                        )
-                )
+                .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
+                .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
+                .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
+                .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
                .build();
        MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
        frozenNetwork.init();
@@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {

    @Test
    public void testComputationGraphVsSgd() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
        String frozenBranchName = "B1-";
        String unfrozenBranchName = "B2-";

@@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
-                .addLayer(initialLayer,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        "input"
-                )
-                .addLayer(frozenBranchUnfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(frozenBranchFrozenLayer1,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(3)
-                                        .nOut(4)
-                                        .build()
-                        ),
-                        frozenBranchUnfrozenLayer0
-                )
+                .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
+                .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
+                .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                        new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
                .addLayer(frozenBranchFrozenLayer2,
                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new DenseLayer.Builder()
-                                        .nIn(4)
-                                        .nOut(2)
-                                        .build()
-                        ),
-                        frozenBranchFrozenLayer1
-                )
-                .addLayer(unfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(unfrozenLayer1,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(2)
-                                .build(),
-                        unfrozenLayer0
-                )
-                .addLayer(unfrozenBranch2,
-                        new DenseLayer.Builder()
-                                .nIn(2)
-                                .nOut(1)
-                                .build(),
-                        unfrozenLayer1
-                )
-                .addVertex("merge",
-                        new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
-                                new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                        .activation(Activation.TANH)
-                                        .nIn(3)
-                                        .nOut(1)
-                                        .build()
-                        ),
-                        "merge"
-                )
+                                new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
+                .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
+                .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
+                .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
+                .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+                .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
                .setOutputs(frozenBranchOutput)
                .build();

@@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                .seed(12345)
                .graphBuilder()
                .addInputs("input")
-                .addLayer(initialLayer,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        "input"
-                )
-                .addLayer(frozenBranchUnfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(3)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(frozenBranchFrozenLayer1,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(3)
-                                .nOut(4)
-                                .build(),
-                        frozenBranchUnfrozenLayer0
-                )
-                .addLayer(frozenBranchFrozenLayer2,
-                        new DenseLayer.Builder()
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .nIn(4)
-                                .nOut(2)
-                                .build()
-                        ,
-                        frozenBranchFrozenLayer1
-                )
-                .addLayer(unfrozenLayer0,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(4)
-                                .build(),
-                        initialLayer
-                )
-                .addLayer(unfrozenLayer1,
-                        new DenseLayer.Builder()
-                                .nIn(4)
-                                .nOut(2)
-                                .build(),
-                        unfrozenLayer0
-                )
-                .addLayer(unfrozenBranch2,
-                        new DenseLayer.Builder()
-                                .nIn(2)
-                                .nOut(1)
-                                .build(),
-                        unfrozenLayer1
-                )
-                .addVertex("merge",
-                        new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput,
-                        new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
-                                .updater(new Sgd(0.0))
-                                .biasUpdater(new Sgd(0.0))
-                                .activation(Activation.TANH)
-                                .nIn(3)
-                                .nOut(1)
-                                .build()
-                        ,
-                        "merge"
-                )
+                .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
+                .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
+                .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
+                .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
+                .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
+                .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
+                .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
+                .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
+                .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
                .setOutputs(frozenBranchOutput)
                .build();

@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.exception.DL4JException;
+import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
@@ -693,4 +694,22 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
        INDArray out = net.output(in);
        assertArrayEquals(new long[]{2,7,6}, out.shape());
    }
+
+    @Test
+    public void testDeconvBadInput(){
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .list()
+                .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build())
+                .build();
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5);
+        try {
+            net.output(badInput);
+        } catch (DL4JInvalidInputException e){
+            String msg = e.getMessage();
+            assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
+        }
+    }
 }
@@ -80,6 +80,7 @@ public class TestSameDiffDense extends BaseDL4JTest {

    @Test
    public void testSameDiffDenseForward() {
+        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
            for (int minibatch : new int[]{5, 1}) {
                int nIn = 3;
                int nOut = 4;
@@ -97,8 +98,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
                };

                for (Activation a : afns) {
-                    log.info("Starting test - " + a);
+                    log.info("Starting test - " + a + ", workspace = " + wsm);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .inferenceWorkspaceMode(wsm)
+                            .trainingWorkspaceMode(wsm)
                            .list()
                            .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
                                    .activation(a)
@@ -146,9 +149,11 @@ public class TestSameDiffDense extends BaseDL4JTest {
                }
            }
        }
+    }

    @Test
    public void testSameDiffDenseForwardMultiLayer() {
+        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
            for (int minibatch : new int[]{5, 1}) {
                int nIn = 3;
                int nOut = 4;
@@ -166,7 +171,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
                };

                for (Activation a : afns) {
-                    log.info("Starting test - " + a);
+                    log.info("Starting test - " + a + " - workspace=" + wsm);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .seed(12345)
                            .list()
@@ -201,7 +206,6 @@ public class TestSameDiffDense extends BaseDL4JTest {
                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                    net2.init();

-                    // net.params().assign(net2.params());
                    assertEquals(net2.params(), net.params());

                    //Check params:
@@ -231,6 +235,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
                }
            }
        }
+    }

    @Test
    public void testSameDiffDenseBackward() {
@@ -244,10 +249,13 @@ public class TestSameDiffDense extends BaseDL4JTest {
            Activation[] afns = new Activation[]{
                    Activation.TANH,
                    Activation.SIGMOID,
-                    Activation.ELU, Activation.IDENTITY, Activation.SOFTPLUS, Activation.SOFTSIGN,
+                    Activation.ELU,
+                    Activation.IDENTITY,
+                    Activation.SOFTPLUS,
+                    Activation.SOFTSIGN,
                    Activation.HARDTANH,
-                    Activation.CUBE, //https://github.com/deeplearning4j/nd4j/issues/2426
-                    Activation.RELU //JVM crash
+                    Activation.CUBE,
+                    Activation.RELU
            };

            for (Activation a : afns) {
@@ -337,12 +345,13 @@ public class TestSameDiffDense extends BaseDL4JTest {

        int nIn = 4;
        int nOut = 3;
-        boolean workspaces = true;

+        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
+
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(12345)
-                    .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
-                    .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
+                    .trainingWorkspaceMode(wsm)
+                    .inferenceWorkspaceMode(wsm)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
@@ -373,7 +382,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
            assertEquals(netStandard.params(), netSD.params());
            assertEquals(netStandard.paramTable(), netSD.paramTable());

-            DataSetIterator iter = new IrisDataSetIterator(150,150);
+            DataSetIterator iter = new IrisDataSetIterator(150, 150);
            DataSet ds = iter.next();

            INDArray outSD = netSD.output(ds.getFeatures());
@@ -381,7 +390,7 @@ public class TestSameDiffDense extends BaseDL4JTest {

            assertEquals(outStd, outSD);

-            for( int i=0; i<3; i++ ){
+            for (int i = 0; i < 3; i++) {
                netSD.fit(ds);
                netStandard.fit(ds);
                String s = String.valueOf(i);
@@ -396,13 +405,14 @@ public class TestSameDiffDense extends BaseDL4JTest {
            INDArray outMb = netStandard.output(newIn);
            assertEquals(outMb, outMbsd);
        }
+    }

    @Test
    public void gradientCheck() {
        int nIn = 4;
        int nOut = 4;

-        for (boolean workspaces : new boolean[]{false, true}) {
+        for (boolean workspaces : new boolean[]{true, false}) {
            for (Activation a : new Activation[]{Activation.TANH, Activation.IDENTITY}) {

                String msg = "workspaces: " + workspaces + ", " + a;
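The TestSameDiffDense and TestSameDiffLambda changes in this merge all follow the same shape: a test that previously ran under one hard-coded workspace setting is wrapped in a loop over both WorkspaceMode values, and the mode is threaded into the configuration builder so every assertion runs with workspaces enabled and disabled. A minimal sketch of the pattern, assuming the usual builder calls shown in the hunks above:

    for (WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(wsm)      // exercise both workspace settings
                .inferenceWorkspaceMode(wsm)
                .list()
                // ... layers as in the tests above ...
                .build();
        // build the network and run the same assertions for each mode
    }
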
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
 import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
 import org.deeplearning4j.nn.conf.graph.ScaleVertex;
 import org.deeplearning4j.nn.conf.graph.ShiftVertex;
@@ -52,8 +53,14 @@ public class TestSameDiffLambda extends BaseDL4JTest {

    @Test
    public void testSameDiffLamdaLayerBasic(){
+        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
+            log.info("--- Workspace Mode: {} ---", wsm);
+
+
            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                    .trainingWorkspaceMode(wsm)
+                    .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
@@ -67,6 +74,8 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            //Equavalent, not using SameDiff Lambda:
            ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
+                    .trainingWorkspaceMode(wsm)
+                    .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
@@ -87,7 +96,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            lambda.setParams(std.params());

-            INDArray in = Nd4j.rand(3,5);
+            INDArray in = Nd4j.rand(3, 5);
            INDArray labels = TestUtils.randomOneHot(3, 5);
            DataSet ds = new DataSet(in, labels);

@@ -101,7 +110,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            assertEquals(scoreStd, scoreLambda, 1e-6);

-            for( int i=0; i<3; i++ ){
+            for (int i = 0; i < 3; i++) {
                lambda.fit(ds);
                std.fit(ds);

@@ -122,11 +131,17 @@ public class TestSameDiffLambda extends BaseDL4JTest {
            INDArray outMb = std.output(newIn)[0];
            assertEquals(outMb, outMbsd);
        }
+    }

    @Test
    public void testSameDiffLamdaVertexBasic(){
+        for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
+            log.info("--- Workspace Mode: {} ---", wsm);
+
            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                    .trainingWorkspaceMode(wsm)
+                    .inferenceWorkspaceMode(wsm)
                    .dataType(DataType.DOUBLE)
                    .seed(12345)
                    .updater(new Adam(0.01))
@@ -142,6 +157,8 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            //Equavalent, not using SameDiff Lambda:
            ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
+                    .trainingWorkspaceMode(wsm)
+                    .inferenceWorkspaceMode(wsm)
                    .dataType(DataType.DOUBLE)
                    .seed(12345)
                    .updater(new Adam(0.01))
@@ -163,8 +180,8 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            lambda.setParams(std.params());

-            INDArray in1 = Nd4j.rand(3,5);
-            INDArray in2 = Nd4j.rand(3,5);
+            INDArray in1 = Nd4j.rand(3, 5);
+            INDArray in2 = Nd4j.rand(3, 5);
            INDArray labels = TestUtils.randomOneHot(3, 5);
            MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{in1, in2}, new INDArray[]{labels});

@@ -178,7 +195,7 @@ public class TestSameDiffLambda extends BaseDL4JTest {

            assertEquals(scoreStd, scoreLambda, 1e-6);

-            for( int i=0; i<3; i++ ){
+            for (int i = 0; i < 3; i++) {
                lambda.fit(mds);
                std.fit(mds);

@@ -200,4 +217,5 @@ public class TestSameDiffLambda extends BaseDL4JTest {
            INDArray outMb = std.output(newIn1, newIn2)[0];
            assertEquals(outMb, outMbsd);
        }
+    }
 }
@@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Adam;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.nd4j.linalg.primitives.Pair;

+import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.Collections;

+import static junit.framework.TestCase.*;
 import static org.junit.Assume.assumeTrue;

 public class ValidateMKLDNN extends BaseDL4JTest {
@@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
                    .padding(0, 0)
                    .nOut(3)
                    .build())
-            .layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build())
+            .layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
            .layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
                    .kernelSize(kernel)
                    .stride(stride)
@@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest {
            }
        }
    }
+
+    @Test
+    public void compareBatchNormBackward() throws Exception {
+
+        Nd4j.getRandom().setSeed(12345);
+        INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
+        INDArray mean = in.mean(0, 2, 3).reshape(1,3);
+        INDArray var = in.var(0, 2, 3).reshape(1,3);
+        INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
+        INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
+        INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
+        double e = 1e-3;
+
+        INDArray dLdIn = in.ulike();
+        INDArray dLdm = mean.ulike();
+        INDArray dLdv = var.ulike();
+        INDArray dLdg = gamma.ulike();
+        INDArray dLdb = beta.ulike();
+
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .inferenceWorkspaceMode(WorkspaceMode.NONE)
+                .trainingWorkspaceMode(WorkspaceMode.NONE)
+                .list()
+                .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
+                .build();
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+        org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
+        assertNotNull(bn.getHelper());
+        System.out.println(bn.getHelper());
+
+        net.output(in, true);
+        bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
+        Pair<Gradient,INDArray> pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
+
+        Field f = bn.getClass().getDeclaredField("helper");
+        f.setAccessible(true);
+        f.set(bn, null);
+        assertNull(bn.getHelper());
+
+        net.output(in, true);
+        bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
+        Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
+
+        INDArray dldin_dl4j = p.getSecond();
+        INDArray dldin_helper = pcudnn.getSecond();
+
+        assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5));
+    }
 }
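compareBatchNormBackward above validates the MKL-DNN/cuDNN batch-norm helper against the pure-Java path by running backprop twice on the same network: once with the helper attached, then again after the private "helper" field has been nulled out via reflection, asserting that the two input gradients agree within tolerance. A condensed sketch of that force-the-fallback trick (same field name as in the test above):

    // Run once with the helper, then disable it and compare against the built-in implementation.
    Pair<Gradient,INDArray> withHelper = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());

    Field f = bn.getClass().getDeclaredField("helper");   // private field on the BatchNormalization layer
    f.setAccessible(true);
    f.set(bn, null);                                       // force the non-helper code path

    Pair<Gradient,INDArray> builtIn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
    assertTrue(withHelper.getSecond().equalsWithEps(builtIn.getSecond(), 1e-5));
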
@@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -340,7 +341,8 @@ public class BackPropMLPTest extends BaseDL4JTest {

    public static float[] asFloat(INDArray arr) {
        long len = arr.length();
-        // FIXME: int cast
+        if (len > Integer.MAX_VALUE)
+            throw new ND4JArraySizeException();
        float[] f = new float[(int) len];
        NdIndexIterator iterator = new NdIndexIterator('c', arr.shape());
        for (int i = 0; i < len; i++) {
@@ -320,7 +320,6 @@ public class MultiLayerTest extends BaseDL4JTest {
    public static float[] asFloat(INDArray arr) {
        long len = arr.length();

-        // FIXME: int cast
        float[] f = new float[(int) len];
        for (int i = 0; i < len; i++)
            f[i] = arr.getFloat(i);
@@ -331,7 +331,6 @@ public class TestUpdaters extends BaseDL4JTest {
        double calculatedByHandMScalar = 0.2;
        double[] expectedM = Nd4j.ones(1, numParams).mul(calculatedByHandMScalar).data().asDouble();

-        // FIXME: int cast
        double[] actualM = Arrays.copyOfRange(nadamUpdater.getM().data().asDouble(), 0, (int) numParams);
        for (int i = 0; i < actualM.length; i++) {
            actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;
@@ -48,6 +48,7 @@ import org.nd4j.linalg.api.rng.DefaultRandom;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.conditions.Condition;
 import org.nd4j.linalg.learning.config.AdaGrad;
@@ -664,8 +665,10 @@ public class TestOptimizers extends BaseDL4JTest {
            double xlm1 = parameters.getDouble(nDims - 2);
            double gl = 200 * (xl - xlm1 * xlm1);

-            // FIXME: int cast
-            gradient.put(0, (int) nDims - 1, gl);
+            if (nDims - 1 > Integer.MAX_VALUE) {
+                throw new ND4JArraySizeException();
+            }
+            gradient.put(0, (int)nDims - 1, gl);
            Gradient g = new DefaultGradient();
            g.gradientForVariable().put("W", gradient);
            this.gradient = g;
@@ -865,8 +868,7 @@ public class TestOptimizers extends BaseDL4JTest {

        @Override
        public long numParams() {
-            // FIXME: int cast
-            return (int) parameters.length();
+            return parameters.length();
        }

        @Override
@@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
        SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3);

        SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10);
-        SDVariable b0 = sd.zero("b0", 1, 10);
+        SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10));

        SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3);
-        SDVariable b1 = sd.zero("b1", 1, 3);
+        SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3));

        SDVariable z0 = in.mmul(w0).add(b0);
        SDVariable a0 = sd.nn().tanh(z0);
@@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                Map<String,INDArray> placeholders = new HashMap<>();
                placeholders.put("input", f);
                placeholders.put("label", l);
-                sd.exec(placeholders, lossMse.getVarName());
-                INDArray outSd = a1.getArr();
+                Map<String,INDArray> map = sd.output(placeholders, lossMse.name(), a1.name());
+                INDArray outSd = map.get(a1.name());
                INDArray outDl4j = net.output(f);

                assertEquals(testName, outDl4j, outSd);
@@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {

                //Check score
                double scoreDl4j = net.score();
-                double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
+                double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
                assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

                double lossRegScoreSD = sd.calcRegularizationScore();
@@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest {

                //Check gradients (before updater applied)
                Map<String,INDArray> grads = net.gradient().gradientForVariable();
-                sd.execBackwards(placeholders);
+                Map<String,INDArray> gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name());

                //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
                //We can check correctness though with training param checks later
                if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
-                    assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr());
-                    assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr());
-                    assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr());
-                    assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr());
+                    assertEquals(testName, grads.get("1_b"), gm.get(b1.name()));
+                    assertEquals(testName, grads.get("1_W"), gm.get(w1.name()));
+                    assertEquals(testName, grads.get("0_b"), gm.get(b0.name()));
+                    assertEquals(testName, grads.get("0_W"), gm.get(w0.name()));
                }

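The CompareTrainingImplementations hunks above track a SameDiff API migration: sd.exec(placeholders, ...) followed by SDVariable.getArr() becomes sd.output(...), which returns the requested arrays in a map keyed by variable name, and sd.execBackwards(...) plus getGradient().getArr() becomes sd.calculateGradients(...), which likewise returns a name-to-gradient map. A minimal sketch of the new call pattern as used above (variable names taken from the test):

    // Forward pass: request specific outputs by variable name.
    Map<String, INDArray> out = sd.output(placeholders, lossMse.name(), a1.name());
    INDArray activations = out.get(a1.name());

    // Backward pass: request gradients for specific variables by name.
    Map<String, INDArray> grads = sd.calculateGradients(placeholders, w0.name(), b0.name(), w1.name(), b1.name());
    INDArray dLdW0 = grads.get(w0.name());
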
@@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
    }

    @Override
-    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
+    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                    INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.eps = eps;
        val miniBatch = (int) input.size(0);
@@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba

        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
                        dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0],
-                        shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
+                        (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
@@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
        Pointer varCacheData = allocator.getPointer(varCache, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
-        checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
+        checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
                        cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
                        cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
                        dBetaData, eps, meanCacheData, varCacheData));
@@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba


    @Override
-    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean,
+    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
                    INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
        this.eps = eps;
        final boolean isHalf = (x.dataType() == DataType.HALF);
@@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
        checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                        dstStride[0], dstStride[1], dstStride[2], dstStride[3]));

-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0],
-                        shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
+                        (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context =
@@ -286,7 +286,7 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S
             for (INDArray w : exampleData) {
                 val n = w.size(0);
 
-                // FIXME: int cast
+                if (Math.min(minExamples, n) < Integer.MAX_VALUE)
                     minExamples = (int) Math.min(minExamples, n);
             }
         }
 
@@ -366,7 +366,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
         DataSet ds = mdsToDataSet(mds);
 
         if (totalOutcomes == -1) {
-            // FIXME: int cast
            inputColumns = (int) ds.getFeatures().size(1);
            totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1);
        }
 
@@ -394,7 +393,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
            stored = next();
            useStored = true;
 
-            // FIXME: int cast
            inputColumns = (int) stored.getFeatures().size(1);
            totalOutcomes = (int) stored.getLabels().size(1);
        }
 
@@ -172,7 +172,6 @@ public abstract class AbstractDataSetIterator<T> implements DataSetIterator {
             Pair<T, T> pair = iterator.next();
             if (numFeatures < 1) {
                 if (pair.getFirst() instanceof INDArray) {
-                    // FIXME: int cast
                     numFeatures = (int) ((INDArray) pair.getFirst()).length();
                     numLabels = (int) ((INDArray) pair.getSecond()).length();
                 } else if (pair.getFirst() instanceof float[]) {
@@ -95,7 +95,6 @@ public class IteratorDataSetIterator implements DataSetIterator {
             //Set columns etc for later use
             DataSet temp = list.get(0);
 
-            // FIXME: int cast
             inputColumns = (int) temp.getFeatures().size(1);
             totalOutcomes = temp.getLabels() == null ? 0 : (int) temp.getLabels().size(1); //May be null for layerwise pretraining
         }
 
@@ -73,8 +73,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator {
                 next = iterator.next();
             }
 
-            // FIXME: int cast
-            int nExamples = (int) next.getFeatures(0).size(0);
+            long nExamples = next.getFeatures(0).size(0);
             if (countSoFar + nExamples <= batchSize) {
                 //Add the entire MultiDataSet as-is
                 list.add(next);
@@ -140,7 +139,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator {
         return out;
     }
 
-    private static INDArray getRange(INDArray arr, int exampleFrom, int exampleToExclusive) {
+    private static INDArray getRange(INDArray arr, long exampleFrom, long exampleToExclusive) {
         if (arr == null)
             return null;
 
@@ -134,7 +134,7 @@ public abstract class BaseFileIterator<T, P> implements Iterator<T> {
         List<T> remainder = new ArrayList<>();
         int soFar = 0;
         for (T t : toMerge) {
-            int size = sizeOf(t);
+            long size = sizeOf(t);
 
             if (soFar + size <= batchSize) {
                 correctNum.add(t);
@@ -190,7 +190,7 @@ public abstract class BaseFileIterator<T, P> implements Iterator<T> {
 
     protected abstract T load(File f);
 
-    protected abstract int sizeOf(T of);
+    protected abstract long sizeOf(T of);
 
     protected abstract List<T> split(T toSplit);
 
@@ -151,7 +151,7 @@ public class FileDataSetIterator extends BaseFileIterator<DataSet, DataSetPrePro
     }
 
     @Override
-    protected int sizeOf(DataSet of) {
+    protected long sizeOf(DataSet of) {
         return of.numExamples();
     }
 
 
@@ -151,9 +151,8 @@ public class FileMultiDataSetIterator extends BaseFileIterator<MultiDataSet, Mul
     }
 
     @Override
-    protected int sizeOf(MultiDataSet of) {
-        // FIXME: int cast
-        return (int) of.getFeatures(0).size(0);
+    protected long sizeOf(MultiDataSet of) {
+        return of.getFeatures(0).size(0);
     }
 
     @Override
 
@@ -665,8 +665,7 @@ public class BarnesHutTsne implements Model {
 
         if (useAdaGrad) {
             if (adaGrad == null) {
-                // FIXME: int cast
-                adaGrad = new AdaGrad(ArrayUtil.toInts(gradient.shape()), learningRate);
+                adaGrad = new AdaGrad(gradient.shape(), learningRate);
                 adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1, gradChange.length()),
                                 gradChange.shape(), gradient.ordering(), true);
             }
 
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
 
 
 import lombok.val;
+import org.apache.commons.lang3.ArrayUtils;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -51,6 +52,13 @@ public class KerasReshape extends KerasLayer {
         this(layerConfig, true);
     }
 
+    private long[] listToLongArray(List<Integer> list) {
+        long[] retVal = new long[list.size()];
+        for (int i = 0; i < list.size(); ++i) {
+            retVal[i] = list.get(i);
+        }
+        return retVal;
+    }
     /**
      * Constructor from parsed Keras layer configuration dictionary.
      *
@@ -67,9 +75,7 @@ public class KerasReshape extends KerasLayer {
         if (innerConfig.containsKey(targetShape)) {
             @SuppressWarnings("unchecked")
             List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
-            // FIXME: int cast
-            this.targetShape = ArrayUtil.toLongArray(ArrayUtil.toArray(targetShapeList));
+            this.targetShape = listToLongArray(targetShapeList);
         }
 
     }
 
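For reference, the new `listToLongArray` helper simply widens each boxed `Integer` of the Keras `target_shape` list into a `long` slot, replacing the earlier `ArrayUtil.toLongArray(ArrayUtil.toArray(...))` round trip. A minimal standalone sketch of the same conversion (the class name and the sample values are made up for illustration):

```java
import java.util.Arrays;
import java.util.List;

public class TargetShapeDemo {
    // Same widening logic as KerasReshape#listToLongArray: Integer -> long, element by element.
    static long[] listToLongArray(List<Integer> list) {
        long[] retVal = new long[list.size()];
        for (int i = 0; i < list.size(); ++i) {
            retVal[i] = list.get(i);
        }
        return retVal;
    }

    public static void main(String[] args) {
        // A hypothetical Keras "target_shape" entry, e.g. reshape to (3, 4, 2)
        List<Integer> targetShapeList = Arrays.asList(3, 4, 2);
        System.out.println(Arrays.toString(listToLongArray(targetShapeList))); // [3, 4, 2]
    }
}
```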
@@ -690,13 +690,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
         INDArray testLabels = Nd4j.create(predictionsDl4j.shape());
         if (testLabels.rank() == 2) {
             for (int i = 0; i < testLabels.size(0); i++) {
-                // FIXME: int cast
                 testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), 1.0);
             }
         } else if (testLabels.rank() == 3) {
             for (int i = 0; i < testLabels.size(0); i++) {
                 for (int j = 0; j < testLabels.size(1); j++) {
-                    // FIXME: int cast
                     testLabels.putScalar(i, j, r.nextInt((int) testLabels.size(1)), 1.0);
                 }
             }
 
@@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree;
 
 import lombok.val;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
 
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -28,79 +31,103 @@ import java.util.List;
  */
 public class HyperRect implements Serializable {
 
-    private List<Interval> points;
+    //private List<Interval> points;
+    private float[] lowerEnds;
+    private float[] higherEnds;
+    private INDArray lowerEndsIND;
+    private INDArray higherEndsIND;
 
-    public HyperRect(List<Interval> points) {
-        //this.points = points;
-        this.points = new ArrayList<>(points.size());
-        for (int i = 0; i < points.size(); ++i) {
-            Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher);
-            this.points.add(newInterval);
-        }
+    public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
+        this.lowerEnds = new float[lowerEndsIn.length];
+        this.higherEnds = new float[lowerEndsIn.length];
+        System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
+        System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
+        lowerEndsIND = Nd4j.createFromArray(lowerEnds);
+        higherEndsIND = Nd4j.createFromArray(higherEnds);
+    }
+
+    public HyperRect(float[] point) {
+        this(point, point);
+    }
+
+    public HyperRect(Pair<float[], float[]> ends) {
+        this(ends.getFirst(), ends.getSecond());
     }
 
 
     public void enlargeTo(INDArray point) {
-        for (int i = 0; i < points.size(); i++)
-            points.get(i).enlarge(point.getDouble(i));
+        float[] pointAsArray = point.toFloatVector();
+        for (int i = 0; i < lowerEnds.length; i++) {
+            float p = pointAsArray[i];
+            if (lowerEnds[i] > p)
+                lowerEnds[i] = p;
+            else if (higherEnds[i] < p)
+                higherEnds[i] = p;
+        }
     }
 
-    public static List<Interval> point(INDArray vector) {
-        List<Interval> ret = new ArrayList<>();
+    public static Pair<float[],float[]> point(INDArray vector) {
+        Pair<float[],float[]> ret = new Pair<>();
+        float[] curr = new float[(int)vector.length()];
         for (int i = 0; i < vector.length(); i++) {
-            double curr = vector.getDouble(i);
-            ret.add(new Interval(curr, curr));
+            curr[i] = vector.getFloat(i);
         }
+        ret.setFirst(curr);
+        ret.setSecond(curr);
         return ret;
     }
 
 
-    public List<Boolean> contains(INDArray hPoint) {
-        List<Boolean> ret = new ArrayList<>();
-        for (int i = 0; i < hPoint.length(); i++)
-            ret.add(points.get(i).contains(hPoint.getDouble(i)));
-        return ret;
-    }
-
-    public double minDistance(INDArray hPoint) {
-        double ret = 0.0;
-        for (int i = 0; i < hPoint.length(); i++) {
-            double p = hPoint.getDouble(i);
-            Interval interval = points.get(i);
-            if (!interval.contains(p)) {
-                if (p < interval.lower)
-                    ret += Math.pow((p - interval.lower), 2);
-                else
-                    ret += Math.pow((p - interval.higher), 2);
-            }
-        }
-
-        ret = Math.pow(ret, 0.5);
-        return ret;
+    /*public List<Boolean> contains(INDArray hPoint) {
+        List<Boolean> ret = new ArrayList<>();
+        for (int i = 0; i < hPoint.length(); i++) {
+            ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
+                    higherEnds[i] >= hPoint.getDouble(i));
+        }
+        return ret;
+    }*/
+
+    public double minDistance(INDArray hPoint, INDArray output) {
+        Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
+        return output.getFloat(0);
+
+        /*double ret = 0.0;
+        double[] pointAsArray = hPoint.toDoubleVector();
+        for (int i = 0; i < pointAsArray.length; i++) {
+            double p = pointAsArray[i];
+            if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
+                if (p < lowerEnds[i])
+                    ret += Math.pow((p - lowerEnds[i]), 2);
+                else
+                    ret += Math.pow((p - higherEnds[i]), 2);
+            }
+        }
+        ret = Math.pow(ret, 0.5);
+        return ret;*/
     }
 
     public HyperRect getUpper(INDArray hPoint, int desc) {
-        Interval interval = points.get(desc);
-        double d = hPoint.getDouble(desc);
-        if (interval.higher < d)
+        //Interval interval = points.get(desc);
+        float higher = higherEnds[desc];
+        float d = hPoint.getFloat(desc);
+        if (higher < d)
             return null;
-        HyperRect ret = new HyperRect(new ArrayList<>(points));
-        Interval i2 = ret.points.get(desc);
-        if (i2.lower < d)
-            i2.lower = d;
+        HyperRect ret = new HyperRect(lowerEnds,higherEnds);
+        if (ret.lowerEnds[desc] < d)
+            ret.lowerEnds[desc] = d;
         return ret;
     }
 
     public HyperRect getLower(INDArray hPoint, int desc) {
-        Interval interval = points.get(desc);
-        double d = hPoint.getDouble(desc);
-        if (interval.lower > d)
+        //Interval interval = points.get(desc);
+        float lower = lowerEnds[desc];
+        float d = hPoint.getFloat(desc);
+        if (lower > d)
             return null;
-        HyperRect ret = new HyperRect(new ArrayList<>(points));
-        Interval i2 = ret.points.get(desc);
-        if (i2.higher > d)
-            i2.higher = d;
+        HyperRect ret = new HyperRect(lowerEnds,higherEnds);
+        //Interval i2 = ret.points.get(desc);
+        if (ret.higherEnds[desc] > d)
+            ret.higherEnds[desc] = d;
         return ret;
     }
 
@@ -108,33 +135,10 @@ public class HyperRect implements Serializable {
     public String toString() {
         String retVal = "";
         retVal += "[";
-        for (val point : points) {
-            retVal += "(" + point.lower + " - " + point.higher + ") ";
+        for (int i = 0; i < lowerEnds.length; ++i) {
+            retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
         }
         retVal += "]";
         return retVal;
     }
 
-    public static class Interval {
-        private double lower, higher;
-
-        public Interval(double lower, double higher) {
-            this.lower = lower;
-            this.higher = higher;
-        }
-
-        public boolean contains(double point) {
-            return lower <= point || point <= higher;
-        }
-
-        public void enlarge(double p) {
-            if (lower > p)
-                lower = p;
-            else if (higher < p)
-                higher = p;
-        }
-    }
-
 }
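HyperRect now keeps its bounds as primitive float arrays plus cached INDArray views, and minDistance delegates to the native KnnMinDistance op through a caller-supplied output scalar. A small usage sketch, assuming only the constructors and methods introduced above (the printed values are indicative, not verified output):

```java
import org.deeplearning4j.clustering.kdtree.HyperRect;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class HyperRectDemo {
    public static void main(String[] args) {
        // Degenerate rectangle around a single 2-D point, as KDTree does for its root cell.
        HyperRect rect = new HyperRect(new float[]{7f, 2f});

        // Growing the cell to cover another point updates the per-dimension bounds in place.
        INDArray other = Nd4j.createFromArray(new float[]{5f, 4f});
        rect.enlargeTo(other);
        System.out.println(rect);   // e.g. [(5.0 - 7.0) (2.0 - 4.0) ]

        // minDistance now writes into a caller-supplied scalar buffer and returns its value.
        INDArray out = Nd4j.scalar(0.f);
        double d = rect.minDistance(Nd4j.createFromArray(new float[]{0f, 0f}), out);
        System.out.println(d);
    }
}
```

The single-point constructor used here is what KDTree adopts for its root cell, as the next hunk shows.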
@@ -56,7 +56,7 @@ public class KDTree implements Serializable {
 
         if (root == null) {
             root = new KDNode(point);
-            rect = new HyperRect(HyperRect.point(point));
+            rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
         } else {
             int disc = 0;
             KDNode node = root;
@@ -125,15 +125,21 @@
         return node.getPoint();
     }
 
+    // Share this data for recursive calls of "knn"
+    private float currentDistance;
+    private INDArray currentPoint;
+    private INDArray minDistance = Nd4j.scalar(0.f);
 
-    public List<Pair<Double, INDArray>> knn(INDArray point, double distance) {
-        List<Pair<Double, INDArray>> best = new ArrayList<>();
-        knn(root, point, rect, distance, best, 0);
-        Collections.sort(best, new Comparator<Pair<Double, INDArray>>() {
+
+    public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
+        List<Pair<Float, INDArray>> best = new ArrayList<>();
+        currentDistance = distance;
+        currentPoint = point;
+        knn(root, rect, best, 0);
+        Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
             @Override
-            public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
-                return Double.compare(o1.getKey(), o2.getKey());
+            public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
+                return Float.compare(o1.getKey(), o2.getKey());
             }
         });
 
@@ -141,22 +147,21 @@
     }
 
 
-    private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best,
-                    int _disc) {
-        if (node == null || rect == null || rect.minDistance(point) > dist)
+    private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
+        if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
             return;
         int _discNext = (_disc + 1) % dims;
-        double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult()
-                        .doubleValue();
+        float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
+                        .floatValue();
 
-        if (distance <= dist) {
+        if (distance <= currentDistance) {
             best.add(Pair.of(distance, node.getPoint()));
         }
 
         HyperRect lower = rect.getLower(node.point, _disc);
         HyperRect upper = rect.getUpper(node.point, _disc);
-        knn(node.getLeft(), point, lower, dist, best, _discNext);
-        knn(node.getRight(), point, upper, dist, best, _discNext);
+        knn(node.getLeft(), lower, best, _discNext);
+        knn(node.getRight(), upper, best, _discNext);
     }
 
     /**
@@ -171,7 +176,7 @@
 
     private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
                     int _disc) {
-        if (node == null || rect.minDistance(point) > dist)
+        if (node == null || rect.minDistance(point, minDistance) > dist)
             return Pair.of(Double.POSITIVE_INFINITY, null);
 
         int _discNext = (_disc + 1) % dims;
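With this signature change, callers get Float-keyed pairs back and pass the search radius as a float. A usage sketch modelled on the KDTreeTest fixture below (illustrative only, not a verified test):

```java
import org.deeplearning4j.clustering.kdtree.KDTree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.List;

public class KdTreeKnnDemo {
    public static void main(String[] args) {
        // Two-dimensional tree, same shape of data as the fixture in KDTreeTest.
        KDTree kdTree = new KDTree(2);
        kdTree.insert(Nd4j.createFromArray(new float[]{7, 2}));
        kdTree.insert(Nd4j.createFromArray(new float[]{5, 4}));
        kdTree.insert(Nd4j.createFromArray(new float[]{2, 3}));

        // Radius search: distances and results are now Float-keyed pairs.
        List<Pair<Float, INDArray>> result =
                kdTree.knn(Nd4j.createFromArray(new float[]{8, 1}), 5.0f);
        for (Pair<Float, INDArray> p : result) {
            System.out.println(p.getFirst() + " -> " + p.getSecond());
        }
    }
}
```

Note that the query point and radius are now stashed in instance fields (currentPoint, currentDistance) shared by the recursive calls, so concurrent knn calls on the same KDTree instance would interfere with each other.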
@@ -16,6 +16,8 @@
 
 package org.deeplearning4j.clustering.kdtree;
 
+import org.joda.time.Instant;
+import org.nd4j.shade.guava.base.Stopwatch;
 import org.nd4j.shade.guava.primitives.Doubles;
 import lombok.val;
 import org.deeplearning4j.clustering.BaseDL4JTest;
@@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.shade.guava.primitives.Floats;
 import org.opencv.ml.KNearest;
 
 import java.util.ArrayList;
@@ -35,6 +38,8 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
 
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
@@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest {
     @Before
     public void setUp() {
         kdTree = new KDTree(2);
-        double[] data = new double[]{7,2};
+        float[] data = new float[]{7,2};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{5,4};
+        data = new float[]{5,4};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{2,3};
+        data = new float[]{2,3};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{4,7};
+        data = new float[]{4,7};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{9,6};
+        data = new float[]{9,6};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{8,1};
+        data = new float[]{8,1};
         kdTree.insert(Nd4j.createFromArray(data));
     }
 
@@ -168,26 +173,30 @@
 
     @Test
     public void testKNN() {
-        int n = 10;
-        // make a KD-tree of dimension {#n}
-        KDTree kdTree = new KDTree(n);
-        for (int i = -1; i < n; i++) {
+        int dimensions = 512;
+        int vectorsNo = 50000;
+        // make a KD-tree of dimension {#dimensions}
+        Stopwatch stopwatch = Stopwatch.createStarted();
+        KDTree kdTree = new KDTree(dimensions);
+        for (int i = -1; i < vectorsNo; i++) {
             // Insert a unit vector along each dimension
-            List<Double> vec = new ArrayList<>(n);
-            // i = -1 ensures the origin is in the Tree
-            for (int k = 0; k < n; k++) {
-                vec.add((k == i) ? 1.0 : 0.0);
-            }
-            INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec)));
+            INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions);
             kdTree.insert(indVec);
         }
+        stopwatch.stop();
+        System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS));
 
         Random rand = new Random();
         // random point in the Hypercube
-        List<Double> pt = new ArrayList(n);
-        for (int k = 0; k < n; k++) {
-            pt.add(rand.nextDouble() * 10.0);
+        List<Double> pt = new ArrayList(dimensions);
+        for (int k = 0; k < dimensions; k++) {
+            pt.add(rand.nextFloat() * 10.0);
         }
-        List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0);
+        stopwatch.reset();
+        stopwatch.start();
+        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
+        stopwatch.stop();
+        System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS));
     }
 
     @Test
@@ -195,15 +204,15 @@
         int n = 2;
         KDTree kdTree = new KDTree(n);
 
-        double[] data = new double[]{3,3};
+        float[] data = new float[]{3,3};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{1,1};
+        data = new float[]{1,1};
         kdTree.insert(Nd4j.createFromArray(data));
-        data = new double[]{2,2};
+        data = new float[]{2,2};
         kdTree.insert(Nd4j.createFromArray(data));
 
-        data = new double[]{0,0};
-        List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5);
+        data = new float[]{0,0};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);
 
         assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
         assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
@@ -220,88 +229,88 @@
 
         assertEquals(6, kdTree.size());
 
-        double[] data = new double[]{8,1};
-        List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
-        assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5);
-        assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
-        assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5);
-        assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);
-        assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5);
-        assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5);
-        assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5);
-        assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5);
-        assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5);
-        assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5);
-        assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5);
-        assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5);
+        float[] data = new float[]{8,1};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
+        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
+        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
+        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
+        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
+        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
+        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
+        assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
+        assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
+        assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
+        assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
+        assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
+        assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
     }
 
     @Test
     public void testKNN_2() {
-        double[] data = new double[]{8, 1};
-        List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
-        assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5);
-        assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
-        assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5);
-        assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);
-        assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5);
-        assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5);
+        float[] data = new float[]{8, 1};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
+        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
+        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
+        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
+        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
+        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
+        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
     }
 
     @Test
     public void testKNN_3() {
 
-        double[] data = new double[]{2, 3};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
-        assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
-        assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
-        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
-        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
-        assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
-        assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
-        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
-        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
-        assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5);
-        assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5);
-        assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5);
-        assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5);
+        float[] data = new float[]{2, 3};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
+        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
+        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
+        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
+        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
+        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
+        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
+        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
+        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
+        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
+        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
+        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
+        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
     }
 
 
     @Test
     public void testKNN_4() {
-        double[] data = new double[]{2, 3};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
-        assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
-        assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
-        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
-        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
-        assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
-        assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
+        float[] data = new float[]{2, 3};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
+        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
+        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
+        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
+        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
+        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
+        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
     }
 
     @Test
     public void testKNN_5() {
-        double[] data = new double[]{2, 3};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 20.0);
-        assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5);
-        assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5);
-        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
-        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
-        assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5);
-        assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5);
-        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
-        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
-        assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5);
-        assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5);
-        assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5);
-        assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5);
+        float[] data = new float[]{2, 3};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
+        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
+        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
+        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
+        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
+        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
+        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
+        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
+        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
+        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
+        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
+        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
+        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
     }
 
     @Test
     public void test_KNN_6() {
-        double[] data = new double[]{4, 6};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 10.0);
+        float[] data = new float[]{4, 6};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@@ -318,8 +327,8 @@
 
     @Test
     public void test_KNN_7() {
-        double[] data = new double[]{4, 6};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 5.0);
+        float[] data = new float[]{4, 6};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@@ -334,8 +343,8 @@
 
     @Test
     public void test_KNN_8() {
-        double[] data = new double[]{4, 6};
-        val result = kdTree.knn(Nd4j.createFromArray(data), 20.0);
+        float[] data = new float[]{4, 6};
+        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@@ -392,12 +401,12 @@
         Duration duration = new Duration(start, end);
         System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());
 
-        List<Double> pt = new ArrayList(num);
+        List<Float> pt = new ArrayList(num);
         for (int k = 0; k < n; k++) {
-            pt.add((double)(num / 2));
+            pt.add((float)(num / 2));
         }
         start = System.currentTimeMillis();
-        List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0);
+        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
         end = System.currentTimeMillis();
         duration = new Duration(start, end);
         long elapsed = end - start;
@@ -50,6 +50,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.File;
+import java.io.IOException;
 import java.util.*;
 
 import static org.junit.Assert.*;
@@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest {
         assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
     }
 
+    @Test
+    public void testWordsNearestSum() throws IOException {
+        log.info("Load & Vectorize Sentences....");
+        SentenceIterator iter = new BasicLineIterator(inputFile);
+        TokenizerFactory t = new DefaultTokenizerFactory();
+        t.setTokenPreProcessor(new CommonPreprocessor());
+
+        log.info("Building model....");
+        Word2Vec vec = new Word2Vec.Builder()
+                .minWordFrequency(5)
+                .iterations(1)
+                .layerSize(100)
+                .seed(42)
+                .windowSize(5)
+                .iterate(iter)
+                .tokenizerFactory(t)
+                .build();
+
+        log.info("Fitting Word2Vec model....");
+        vec.fit();
+        log.info("Writing word vectors to text file....");
+        log.info("Closest Words:");
+        Collection<String> lst = vec.wordsNearestSum("day", 10);
+        log.info("10 Words closest to 'day': {}", lst);
+        assertTrue(lst.contains("week"));
+        assertTrue(lst.contains("night"));
+        assertTrue(lst.contains("year"));
+        assertTrue(lst.contains("years"));
+        assertTrue(lst.contains("time"));
+    }
+
     private static void printWords(String target, Collection<String> list, Word2Vec vec) {
         System.out.println("Words close to [" + target + "]:");
         for (String word : list) {
@@ -104,7 +104,7 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo
     }
 
     protected void initAdaGrad() {
-        int[] shape = new int[] {vocab.numWords() + 1, vectorLength};
+        long[] shape = new long[] {vocab.numWords() + 1, vectorLength};
         int length = ArrayUtil.prod(shape);
         adaGrad = new AdaGrad(shape, lr.get());
         adaGrad.setStateViewArray(Nd4j.zeros(shape).reshape(1, length), shape, Nd4j.order(), true);
@@ -124,8 +124,7 @@
         if (adaGrad == null)
             initAdaGrad();
 
-        // FIXME: int cast
-        return adaGrad.getGradient(gradient, column, ArrayUtil.toInts(syn0.shape()));
+        return adaGrad.getGradient(gradient, column, syn0.shape());
     }
 
     @Override
@@ -370,7 +369,6 @@
             else {
                 nextRandom.set(nextRandom.get() * 25214903917L + 11);
 
-                // FIXME: int cast
                 int idx = (int) Math.abs((int) (nextRandom.get() >> 16) % table.length());
 
                 target = table.getInt(idx);
@@ -33,7 +33,6 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.aggregates.Aggregate;
-import org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW;
 import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.util.DeviceLocalNDArray;
@@ -104,11 +104,10 @@ public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorit
 
 
 
-        weightAdaGrad = new AdaGrad(new int[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
+        weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
         bias = Nd4j.create(syn0.rows());
 
-        // FIXME: int cast
-        biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), this.learningRate);
+        biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate);
 
         // maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8);
 
@@ -237,15 +236,13 @@
     private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
         //gradient for word vectors
         INDArray grad1 = contextVector.mul(gradient);
-        // FIXME: int cast
-        INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), ArrayUtil.toInts(syn0.shape()));
+        INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape());
 
         //update vector
         wordVector.subi(update);
 
         double w1Bias = bias.getDouble(element1.getIndex());
-        // FIXME: int cast
-        double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), ArrayUtil.toInts(bias.shape()));
+        double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape());
         double update2 = w1Bias - biasGradient;
         bias.putScalar(element1.getIndex(), update2);
     }
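The updater construction and gradient lookups above now take long[] shapes directly, so the ArrayUtil.toInts round trips disappear. A minimal sketch mirroring that usage; the AdaGrad import path and the standalone setup are assumptions, not taken from this commit:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
// Package assumed: the legacy AdaGrad updater used by GloVe / InMemoryLookupTable.
import org.nd4j.linalg.learning.legacy.AdaGrad;

public class AdaGradLongShapeDemo {
    public static void main(String[] args) {
        INDArray syn0 = Nd4j.rand(10, 5);   // hypothetical weight table: 10 words, 5 dims
        double learningRate = 0.05;

        // Shapes are passed as long[] now, matching the GloVe hunk above.
        AdaGrad weightAdaGrad = new AdaGrad(new long[]{10, 5}, learningRate);

        INDArray grad = Nd4j.rand(1, 5);
        // Per-row gradient for word index 3, using the table's own (long[]) shape.
        INDArray update = weightAdaGrad.getGradient(grad, 3, syn0.shape());
        System.out.println(update);
    }
}
```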
@@ -351,13 +351,13 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
         if (lookupTable instanceof InMemoryLookupTable) {
             InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
             INDArray syn0 = l.getSyn0();
-            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
+            INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
+            INDArray weights = temp.muli(words);
             INDArray distances = syn0.mulRowVector(weights).sum(1);
             INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
             INDArray sort = sorted[0];
             List<String> ret = new ArrayList<>();
 
-            // FIXME: int cast
             if (top > sort.length())
                 top = (int) sort.length();
             //there will be a redundant word
@@ -72,7 +72,7 @@ public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryL
             putVector(Word2Vec.DEFAULT_UNK, randUnk);
         }
         if (weightAdaGrad == null || reset) {
-            weightAdaGrad = new AdaGrad(new int[] {vocab.numWords() + 1, vectorLength}, lr.get());
+            weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get());
         }
 
 
@@ -81,7 +81,7 @@
         bias = Nd4j.create(syn0.rows());
 
         if (biasAdaGrad == null || reset) {
-            biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), lr.get());
+            biasAdaGrad = new AdaGrad(bias.shape(), lr.get());
         }
 
 
@@ -140,13 +140,13 @@
     private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
         //gradient for word vectors
         INDArray grad1 = contextVector.mul(gradient);
-        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape()));
+        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
 
         //update vector
         wordVector.subi(update);
 
         double w1Bias = bias.getDouble(w1.getIndex());
-        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape()));
+        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
         double update2 = w1Bias - biasGradient;
         bias.putScalar(w1.getIndex(), update2);
     }
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.deeplearning4j.nn.api.Model;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.function.Consumer;
 import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
 import org.nd4j.linalg.primitives.Pair;
@@ -293,7 +294,8 @@ public class GradientCheckUtil {
                 ss = n;
             }
 
-            // FIXME: int cast
+            if (ss > Integer.MAX_VALUE)
+                throw new ND4JArraySizeException();
             stepSizeForParam.put(paramNames.get(i), (int) ss);
         }
     }
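GradientCheckUtil now guards the narrowing cast with ND4JArraySizeException instead of carrying a FIXME. The same guard could be factored into a small helper; this is a hypothetical sketch, not a utility that exists in the codebase:

```java
import org.nd4j.linalg.exception.ND4JArraySizeException;

public final class SafeCast {
    private SafeCast() {}

    // Hypothetical helper: the same upper-bound guard as GradientCheckUtil above, factored out.
    // Fails fast instead of silently truncating a long dimension that exceeds int range.
    public static int toInt(long value) {
        if (value > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        return (int) value;
    }

    public static void main(String[] args) {
        System.out.println(toInt(42L));   // 42
        // toInt(5_000_000_000L);         // would throw ND4JArraySizeException
    }
}
```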
@@ -140,10 +140,9 @@ public class ElementWiseVertex extends GraphVertex {
             //CNN inputs... also check that the channels, width and heights match:
             InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
 
-            // FIXME: int cast
-            val fd = (int) firstConv.getChannels();
-            val fw = (int) firstConv.getWidth();
-            val fh = (int) firstConv.getHeight();
+            val fd = firstConv.getChannels();
+            val fw = firstConv.getWidth();
+            val fh = firstConv.getHeight();
 
             for (int i = 1; i < vertexInputs.length; i++) {
                 if (vertexInputs[i].getType() != InputType.Type.CNN) {
@@ -155,10 +154,9 @@
 
                 InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
 
-                // FIXME: int cast
-                val od = (int) otherConv.getChannels();
-                val ow = (int) otherConv.getWidth();
-                val oh = (int) otherConv.getHeight();
+                val od = otherConv.getChannels();
+                val ow = otherConv.getWidth();
+                val oh = otherConv.getHeight();
 
                 if (fd != od || fw != ow || fh != oh) {
                     throw new InvalidInputTypeException(
@@ -94,13 +94,12 @@ public class MergeVertex extends GraphVertex {
             // CNN3D inputs: check that the channels, width and height match:
             InputType.InputTypeConvolutional3D firstConv = (InputType.InputTypeConvolutional3D) first;
 
-            // FIXME: int cast
-            val fd = (int) firstConv.getDepth();
-            val fw = (int) firstConv.getWidth();
-            val fh = (int) firstConv.getHeight();
-            val fc = (int) firstConv.getChannels();
+            val fd = firstConv.getDepth();
+            val fw = firstConv.getWidth();
+            val fh = firstConv.getHeight();
+            val fc = firstConv.getChannels();
 
-            int depthSum = fc;
+            long depthSum = fc;
             InputType.InputTypeConvolutional3D otherConv = null;
             for (int i = 1; i < vertexInputs.length; i++) {
                 if (vertexInputs[i].getType() != InputType.Type.CNN3D) {
@@ -109,10 +108,10 @@
                 }
 
                 otherConv = (InputType.InputTypeConvolutional3D) vertexInputs[i];
-                val od = (int) otherConv.getDepth();
-                val ow = (int) otherConv.getWidth();
-                val oh = (int) otherConv.getHeight();
-                val oc = (int) otherConv.getChannels();
+                val od = otherConv.getDepth();
+                val ow = otherConv.getWidth();
+                val oh = otherConv.getHeight();
+                val oc = otherConv.getChannels();
 
                 if (fd != od || fw != ow || fh != oh) {
                     throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN3D activations of different width/heights:" + "first [channels,width,height] = [" + fd + "," + fw + "," + fh
@@ -177,12 +176,11 @@
             //CNN inputs... also check that the channels, width and heights match:
             InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
 
-            // FIXME: int cast
-            val fd = (int) firstConv.getChannels();
-            val fw = (int) firstConv.getWidth();
-            val fh = (int) firstConv.getHeight();
+            val fd = firstConv.getChannels();
+            val fw = firstConv.getWidth();
+            val fh = firstConv.getHeight();
 
-            int depthSum = fd;
+            long depthSum = fd;
 
             for (int i = 1; i < vertexInputs.length; i++) {
                 if (vertexInputs[i].getType() != InputType.Type.CNN) {
@@ -194,10 +192,9 @@
 
                 InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
 
-                // FIXME: int cast
-                val od = (int) otherConv.getChannels();
-                val ow = (int) otherConv.getWidth();
-                val oh = (int) otherConv.getHeight();
+                val od = otherConv.getChannels();
+                val ow = otherConv.getWidth();
+                val oh = otherConv.getHeight();
 
                 if (fw != ow || fh != oh) {
                     throw new InvalidInputTypeException(
@ -131,12 +131,11 @@ public class PoolHelperVertex extends GraphVertex {
|
||||||
//CNN inputs... also check that the channels, width and heights match:
|
//CNN inputs... also check that the channels, width and heights match:
|
||||||
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
|
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
|
||||||
|
|
||||||
// FIXME: int cast
|
val fd = firstConv.getChannels();
|
||||||
val fd = (int) firstConv.getChannels();
|
val fw = firstConv.getWidth();
|
||||||
val fw = (int) firstConv.getWidth();
|
val fh = firstConv.getHeight();
|
||||||
val fh = (int) firstConv.getHeight();
|
|
||||||
|
|
||||||
int depthSum = fd;
|
long depthSum = fd;
|
||||||
|
|
||||||
for (int i = 1; i < vertexInputs.length; i++) {
|
for (int i = 1; i < vertexInputs.length; i++) {
|
||||||
if (vertexInputs[i].getType() != InputType.Type.CNN) {
|
if (vertexInputs[i].getType() != InputType.Type.CNN) {
|
||||||
|
@ -148,10 +147,9 @@ public class PoolHelperVertex extends GraphVertex {
|
||||||
|
|
||||||
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
|
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
|
||||||
|
|
||||||
// FIXME: int cast
|
long od = otherConv.getChannels();
|
||||||
int od = (int) otherConv.getChannels();
|
long ow = otherConv.getWidth();
|
||||||
int ow = (int) otherConv.getWidth();
|
long oh = otherConv.getHeight();
|
||||||
int oh = (int) otherConv.getHeight();
|
|
||||||
|
|
||||||
if (fw != ow || fh != oh) {
|
if (fw != ow || fh != oh) {
|
||||||
throw new InvalidInputTypeException(
|
throw new InvalidInputTypeException(
|
||||||
|
|
|
@ -150,12 +150,11 @@ public class UnstackVertex extends GraphVertex {
|
||||||
//CNN inputs... also check that the channels, width and heights match:
|
//CNN inputs... also check that the channels, width and heights match:
|
||||||
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
|
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
|
||||||
|
|
||||||
// FIXME: int cast
|
val fd = firstConv.getChannels();
|
||||||
val fd = (int) firstConv.getChannels();
|
val fw = firstConv.getWidth();
|
||||||
val fw = (int) firstConv.getWidth();
|
val fh = firstConv.getHeight();
|
||||||
val fh = (int) firstConv.getHeight();
|
|
||||||
|
|
||||||
int depthSum = fd;
|
long depthSum = fd;
|
||||||
|
|
||||||
for (int i = 1; i < vertexInputs.length; i++) {
|
for (int i = 1; i < vertexInputs.length; i++) {
|
||||||
if (vertexInputs[i].getType() != InputType.Type.CNN) {
|
if (vertexInputs[i].getType() != InputType.Type.CNN) {
|
||||||
|
@ -167,10 +166,9 @@ public class UnstackVertex extends GraphVertex {
|
||||||
|
|
||||||
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
|
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
|
||||||
|
|
||||||
// FIXME: int cast
|
val od = otherConv.getChannels();
|
||||||
val od = (int) otherConv.getChannels();
|
val ow = otherConv.getWidth();
|
||||||
val ow = (int) otherConv.getWidth();
|
val oh = otherConv.getHeight();
|
||||||
val oh = (int) otherConv.getHeight();
|
|
||||||
|
|
||||||
if (fw != ow || fh != oh) {
|
if (fw != ow || fh != oh) {
|
||||||
throw new InvalidInputTypeException(
|
throw new InvalidInputTypeException(
|
||||||
|
|
|
@ -402,18 +402,17 @@ public abstract class InputType implements Serializable {
        //Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something
        // like FeedForwardToCnnPreProcessor

-       // FIXME: int cast
        switch (inputArray.rank()) {
            case 2:
-               return InputType.feedForward((int) inputArray.size(1));
+               return InputType.feedForward(inputArray.size(1));
            case 3:
-               return InputType.recurrent((int) inputArray.size(1), (int) inputArray.size(2));
+               return InputType.recurrent(inputArray.size(1), (int) inputArray.size(2));
            case 4:
                //Order: [minibatch, channels, height, width] -> [h, w, c]
-               return InputType.convolutional((int) inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1));
+               return InputType.convolutional(inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1));
            case 5:
                //Order: [minibatch, channels, depth, height, width] -> [d, h, w, c]
-               return InputType.convolutional3D((int) inputArray.size(2), (int) inputArray.size(3),
+               return InputType.convolutional3D(inputArray.size(2), (int) inputArray.size(3),
                        (int) inputArray.size(4), (int) inputArray.size(1));
            default:
                throw new IllegalArgumentException(
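Note: a minimal usage sketch (not part of the patch) of the rank-based inference in the hunk above, assuming it lives in InputType.inferInputType(INDArray); the array shape here is an arbitrary example.

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class InferInputTypeExample {
    public static void main(String[] args) {
        // Rank-4 activations [minibatch, channels, height, width] map to a convolutional InputType
        INDArray act = Nd4j.create(DataType.FLOAT, 8, 3, 32, 32);
        InputType t = InputType.inferInputType(act);
        System.out.println(t); // expected: convolutional type with h=32, w=32, c=3
    }
}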
@ -152,17 +152,18 @@ public class Cnn3DLossLayer extends FeedForwardLayer {
        }

        @Override
-       public void setNIn(int nIn){
+       public void setNIn(long nIn){
            throw new UnsupportedOperationException(
                    "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut.");
        }

        @Override
-       public void setNOut(int nOut){
+       public void setNOut(long nOut){
            throw new UnsupportedOperationException(
                    "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut.");
        }


        @Override
        @SuppressWarnings("unchecked")
        public Cnn3DLossLayer build() {

@ -145,13 +145,13 @@ public class CnnLossLayer extends FeedForwardLayer {
        }

        @Override
-       public void setNIn(int nIn){
+       public void setNIn(long nIn){
            throw new UnsupportedOperationException(
                    "This layer has no parameters, thus nIn will always equal nOut.");
        }

        @Override
-       public void setNOut(int nOut){
+       public void setNOut(long nOut){
            throw new UnsupportedOperationException(
                    "This layer has no parameters, thus nIn will always equal nOut.");
        }

@ -88,7 +88,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
            //Probably: user did InputType.recurrent(x) without specifying sequence length
            outLength = -1;
        } else {
-           outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0],
+           outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
                    convolutionMode, dilation[0]);
        }
        return InputType.recurrent(nOut, outLength);

@ -117,14 +117,14 @@ public abstract class FeedForwardLayer extends BaseLayer {
     * this is the input channels, otherwise is the previous layer size.
     *
     */
-   protected int nIn = 0;
+   protected long nIn = 0;

    /**
     * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
     * this is the input channels, otherwise is the previous layer size.
     *
     */
-   protected int nOut = 0;
+   protected long nOut = 0;

    /**
     * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,

@ -144,8 +144,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
     * @param nIn Number of inputs for the layer
     */
    public T nIn(long nIn) {
-       // FIXME: int cast
-       this.setNIn((int) nIn);
+       this.setNIn(nIn);
        return (T) this;
    }
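Note: a small sketch (not part of the patch) of what the FeedForwardLayer change above means for configuration code; the DenseLayer type and the concrete sizes are assumptions for illustration.

import org.deeplearning4j.nn.conf.layers.DenseLayer;

public class LongNInExample {
    public static void main(String[] args) {
        // nIn/nOut are kept as long end-to-end now; nIn(long) no longer narrows to int internally
        long nIn = 784L;
        long nOut = 256L;
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(nIn)
                .nOut(nOut)
                .build();
        System.out.println(layer.getNIn() + " -> " + layer.getNOut()); // 784 -> 256
    }
}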
@ -41,12 +41,9 @@ public class InputTypeUtil {
            Class<?> layerClass) {
        InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;

-       // FIXME: int cast
-       val hIn = (int) i.getHeight();
-       val wIn = (int) i.getWidth();
+       val hIn = i.getHeight();
+       val wIn = i.getWidth();

-       val inHeight = (int) i.getHeight();
-       val inWidth = (int) i.getWidth();
        int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same
        int padW = (padding == null ? 0 : padding[1]);
        int kH = kernelSize[0];

@ -69,13 +66,13 @@ public class InputTypeUtil {
        }

        if (convolutionMode == ConvolutionMode.Same) {
-           int hOut = stride[0] * hIn;
-           int wOut = stride[1] * wIn;
+           long hOut = stride[0] * hIn;
+           long wOut = stride[1] * wIn;
            return InputType.convolutional(hOut, wOut, outputDepth);
        }

-       int hOut = sH * (hIn - 1) + kH - 2 * padH;
-       int wOut = sW * (wIn - 1) + kW - 2 * padW;
+       long hOut = sH * (hIn - 1) + kH - 2 * padH;
+       long wOut = sW * (wIn - 1) + kW - 2 * padW;

        return InputType.convolutional(hOut, wOut, outputDepth);
    }
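Note: a standalone sketch (not part of the patch) of the deconvolution output-size arithmetic used above, kept in long as the hunk now does; the helper name and example numbers are made up for illustration.

public class DeconvOutputSize {
    // Mirrors the Same-mode and strict-mode formulas from the hunk above
    static long deconvOutSize(long in, int k, int s, int pad, boolean same) {
        if (same) {
            return (long) s * in;                  // Same mode: output = stride * input
        }
        return (long) s * (in - 1) + k - 2L * pad; // strict mode
    }

    public static void main(String[] args) {
        // e.g. hIn=7, kernel 4, stride 2, padding 1 -> 2*(7-1) + 4 - 2 = 14
        System.out.println(deconvOutSize(7, 4, 2, 1, false)); // 14
        System.out.println(deconvOutSize(7, 4, 2, 1, true));  // 14
    }
}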
@ -91,10 +88,9 @@ public class InputTypeUtil {

        InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType;

-       // FIXME: int cast
-       val inDepth = (int) i.getDepth();
-       val inHeight = (int) i.getHeight();
-       val inWidth = (int) i.getWidth();
+       long inDepth = i.getDepth();
+       long inHeight = i.getHeight();
+       long inWidth = i.getWidth();

        int padD = (padding == null ? 0 : padding[0]);
        int padH = (padding == null ? 0 : padding[1]);

@ -211,9 +207,9 @@ public class InputTypeUtil {
            return InputType.convolutional3D(outD, outH, outW, outputChannels);
        }

-       int dOut = (inDepth - kD + 2 * padD) / sD + 1;
-       int hOut = (inHeight - kH + 2 * padH) / sH + 1;
-       int wOut = (inWidth - kW + 2 * padW) / sW + 1;
+       long dOut = (inDepth - kD + 2 * padD) / sD + 1;
+       long hOut = (inHeight - kH + 2 * padH) / sH + 1;
+       long wOut = (inWidth - kW + 2 * padW) / sW + 1;
        return InputType.convolutional3D(dOut, hOut, wOut, outputChannels);
    }

@ -296,9 +292,8 @@ public class InputTypeUtil {

        InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;

-       // FIXME: int cast
-       val inHeight = (int) i.getHeight();
-       val inWidth = (int) i.getWidth();
+       long inHeight = i.getHeight();
+       long inWidth = i.getWidth();

        int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same
        int padW = (padding == null ? 0 : padding[1]);
        int kH = kernelSize[0];

@ -379,8 +374,8 @@ public class InputTypeUtil {
            return InputType.convolutional(outH, outW, outputDepth);
        }

-       int hOut = (inHeight - kH + 2 * padH) / sH + 1;
-       int wOut = (inWidth - kW + 2 * padW) / sW + 1;
+       long hOut = (inHeight - kH + 2 * padH) / sH + 1;
+       long wOut = (inWidth - kW + 2 * padW) / sW + 1;
        return InputType.convolutional(hOut, wOut, outputDepth);
    }
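Note: a standalone sketch (not part of the patch) of the strict-mode convolution size formula that the InputTypeUtil hunks above now evaluate in long; helper name and numbers are illustrative only.

public class ConvOutputSize {
    static long convOutSize(long in, int k, int s, int pad) {
        // (input - kernel + 2*padding) / stride + 1, in long arithmetic
        return (in - k + 2L * pad) / s + 1;
    }

    public static void main(String[] args) {
        // e.g. 224x224 input, 7x7 kernel, stride 2, padding 3 -> (224 - 7 + 6) / 2 + 1 = 112
        System.out.println(convOutSize(224, 7, 2, 3)); // 112
    }
}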
@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer {
        val weightsShape = new long[] {outputSize, featureDim, nOut};
        params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
        if (hasBias) {
-           val biasShape = new long[] {1, nOut};
+           val biasShape = new long[] {nOut};
            params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
        }
    }

@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer {

        if (hasBias) {
            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-           SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b);
+           SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
            return activation.asSameDiff("out", sameDiff, biasAddedResult);
        } else {
            return activation.asSameDiff("out", sameDiff, result);

@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer {
        val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut};
        params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
        if (hasBias) {
-           val biasShape = new long[] {1, nOut};
+           val biasShape = new long[] {nOut};
            params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
        }
    }

@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer {

        if (hasBias) {
            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-           SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b);
+           SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true);
            return activation.asSameDiff("out", sameDiff, biasAddedResult);
        } else {
            return activation.asSameDiff("out", sameDiff, permutedResult);
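Note: a minimal sketch (not part of the patch) of why a rank-1 bias of shape [nOut] behaves like the old [1, nOut] shape when added: it is broadcast over the minibatch dimension. Sizes and values are arbitrary; plain Nd4j row-vector addition stands in for the SameDiff biasAdd used above.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RankOneBiasExample {
    public static void main(String[] args) {
        int minibatch = 4, nOut = 3;
        INDArray preOut = Nd4j.rand(DataType.FLOAT, minibatch, nOut);
        INDArray bias = Nd4j.create(new float[]{0.1f, 0.2f, 0.3f}); // rank 1, shape [nOut]

        // Broadcasting the rank-1 bias across rows gives the same result as a [1, nOut] bias
        INDArray out = preOut.addRowVector(bias);
        System.out.println(java.util.Arrays.toString(out.shape())); // [4, 3]
    }
}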
@ -142,13 +142,13 @@ public class RnnLossLayer extends FeedForwardLayer {
        }

        @Override
-       public void setNIn(int nIn){
+       public void setNIn(long nIn){
            throw new UnsupportedOperationException(
                    "This layer has no parameters, thus nIn will always equal nOut.");
        }

        @Override
-       public void setNOut(int nOut){
+       public void setNOut(long nOut){
            throw new UnsupportedOperationException(
                    "This layer has no parameters, thus nIn will always equal nOut.");
        }

@ -82,12 +82,12 @@ public class Subsampling1DLayer extends SubsamplingLayer {
        }
        InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
        long inputTsLength = r.getTimeSeriesLength();
-       int outLength;
+       long outLength;
        if (inputTsLength < 0) {
            //Probably: user did InputType.recurrent(x) without specifying sequence length
            outLength = -1;
        } else {
-           outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0],
+           outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
                    convolutionMode, dilation[0]);
        }
        return InputType.recurrent(r.getSize(), outLength);

@ -32,6 +32,7 @@ import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.learning.regularization.Regularization;

import java.util.Collection;

@ -138,9 +139,11 @@ public class Subsampling3DLayer extends NoParamLayer {
                            + "\"): Expected CNN input, got " + inputType);
        }

-       // FIXME: int cast
+       long inChannels = ((InputType.InputTypeConvolutional3D) inputType).getChannels();
+       if (inChannels > Integer.MAX_VALUE)
+           throw new ND4JArraySizeException();
        return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation
-               convolutionMode, (int) ((InputType.InputTypeConvolutional3D) inputType).getChannels(),
+               convolutionMode, (int) inChannels,
                layerIndex, getLayerName(), Subsampling3DLayer.class);
    }

@ -83,11 +83,10 @@ public class Upsampling3D extends BaseUpsamplingLayer {
        }
        InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType;

-       // FIXME: int cast
-       int inHeight = (int) i.getHeight();
-       int inWidth = (int) i.getWidth();
-       int inDepth = (int) i.getDepth();
-       int inChannels = (int) i.getChannels();
+       long inHeight = (int) i.getHeight();
+       long inWidth = (int) i.getWidth();
+       long inDepth = (int) i.getDepth();
+       long inChannels = (int) i.getChannels();

        return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels);
    }

@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
        defineVertex(temp, tempInputs);
        List<String> list = new ArrayList<>();
        for (Integer i : tempInputs.map.keySet()) {
-           list.add(tempInputs.map.get(i).getVarName());
+           list.add(tempInputs.map.get(i).name());
        }
        params.defineInputs(list.toArray(new String[list.size()]));
    }

@ -259,7 +259,7 @@ public class OCNNOutputLayer extends BaseOutputLayer {
        }

        @Override
-       public void setNOut(int nOut){
+       public void setNOut(long nOut){
            throw new UnsupportedOperationException(
                    "Unable to specify number of outputs with ocnn. Outputs are fixed to 1.");
        }

@ -79,6 +79,7 @@ import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
@ -3329,7 +3330,6 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
        //In 99+% of cases, the input and labels dimension 0 size should be identical
        //The only real exceptions: space to batch, and batch to space layers
        //In those cases, we should base it on the labels size, as this impacts gradient calculation
-       // FIXME: int cast
        return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
    }

@ -3653,7 +3653,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
            if (endTimeIdx > timeSeriesLength)
                endTimeIdx = timeSeriesLength;

-           // FIXME: int cast
+           if (startTimeIdx > Integer.MAX_VALUE)
+               throw new ND4JArraySizeException();
            List<INDArray[]> list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);

            setInputs(list.get(0));

@ -3799,9 +3800,10 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                    }
                }

-               // FIXME: int cast
+               if (minibatchSize > Integer.MAX_VALUE)
+                   throw new ND4JArraySizeException();
                Pair<INDArray, MaskState> outPair =
-                       current.feedForwardMaskArrays(inputMasks, maskState, (int) minibatchSize);
+                       current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize);
                map.put(topologicalOrder[i], outPair);
            }
        }
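Note: a small sketch (not part of the patch) of the guard pattern the hunks above introduce wherever an int is still required: check the long value against Integer.MAX_VALUE and throw ND4JArraySizeException instead of silently truncating. The helper name is an assumption for illustration.

import org.nd4j.linalg.exception.ND4JArraySizeException;

public class SafeIntCast {
    // Fail loudly instead of silently truncating an over-large long dimension
    static int toIntOrThrow(long value) {
        if (value > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        return (int) value;
    }

    public static void main(String[] args) {
        System.out.println(toIntOrThrow(1024L));   // fine
        // toIntOrThrow(3_000_000_000L);           // would throw ND4JArraySizeException
    }
}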
@ -4664,7 +4666,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
     * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
     * @return Size of the layer
     */
-   public int layerSize(int layer) {
+   public long layerSize(int layer) {
        if (layer < 0 || layer > layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
                    + (layers.length - 1) + " inclusive");

@ -4683,7 +4685,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
     * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
     * @return Size of the layer
     */
-   public int layerInputSize(int layer) {
+   public long layerInputSize(int layer) {
        if (layer < 0 || layer > layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
                    + (layers.length - 1) + " inclusive");

@ -4701,7 +4703,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
     * @param layerName Name of the layer to get the size of
     * @return Size of the layer
     */
-   public int layerSize(String layerName) {
+   public long layerSize(String layerName) {
        Layer l = getLayer(layerName);
        if(l == null){
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");

@ -4712,8 +4714,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
        }
        FeedForwardLayer ffl = (FeedForwardLayer) conf;

-       // FIXME: int cast
-       return (int) ffl.getNOut();
+       return ffl.getNOut();
    }

    /**

@ -4727,7 +4728,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
     * @param layerName Name of the layer to get the size of
     * @return Size of the layer
     */
-   public int layerInputSize(String layerName) {
+   public long layerInputSize(String layerName) {
        Layer l = getLayer(layerName);
        if(l == null){
            throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");

@ -4738,8 +4739,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
        }
        FeedForwardLayer ffl = (FeedForwardLayer) conf;

-       // FIXME: int cast
-       return (int) ffl.getNIn();
+       return ffl.getNIn();
    }

    /**
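Note: a usage sketch (not part of the patch) showing the caller-side effect of layerSize/layerInputSize now returning long; the graph instance and the layer name "dense1" are assumptions for illustration.

import org.deeplearning4j.nn.graph.ComputationGraph;

public class LayerSizeUsage {
    static void printSizes(ComputationGraph graph) {
        long out = graph.layerSize("dense1");       // returned int before this change
        long in  = graph.layerInputSize("dense1");
        int outAsInt = Math.toIntExact(out);        // callers that still need an int must narrow explicitly
        System.out.println(in + " -> " + outAsInt);
    }
}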
@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex {
        }

        try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
-           return Nd4j.hstack(in);
+           return Nd4j.concat(1, in);
        }
    }
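Note: a minimal sketch (not part of the patch) of the behaviour behind the hunk above: concatenating along dimension 1 merges activations on the feature/channel axis for any rank, whereas hstack is the 2d-specific spelling. Shapes below are arbitrary examples.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MergeAlongDim1 {
    public static void main(String[] args) {
        // Two CNN-style activations [minibatch, channels, h, w]; merging stacks the channel dimension
        INDArray a = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4);
        INDArray b = Nd4j.rand(DataType.FLOAT, 2, 5, 4, 4);
        INDArray merged = Nd4j.concat(1, a, b);
        System.out.println(java.util.Arrays.toString(merged.shape())); // [2, 8, 4, 4]
    }
}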
@ -43,10 +43,10 @@ import java.util.Arrays;
 * @author Justin Long (crockpotveggies)
 */
public class UnstackVertex extends BaseGraphVertex {
-   private int from;
+   private long from;
    private int stackSize;
    private long forwardShape[];
-   private int step;
+   private long step;

    public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) {
        this(graph, name, vertexIndex, null, null, from, stackSize, dataType);

@ -77,10 +77,9 @@ public class UnstackVertex extends BaseGraphVertex {
        // once we know the inputs, save the shape and interval size for doBackward
        this.forwardShape = Arrays.copyOf(inputs[0].shape(), inputs[0].rank());

-       // FIXME: int cast
-       this.step = (int) inputs[0].size(0) / stackSize;
-       int start = from * step;
-       int end = (from + 1) * step;
+       this.step = inputs[0].size(0) / stackSize;
+       long start = from * step;
+       long end = (from + 1) * step;

        INDArray ret;
        switch (inputs[0].rank()) { //TODO remove the dups here if/when possible (gradient checks must pass)

@ -108,8 +107,8 @@ public class UnstackVertex extends BaseGraphVertex {
            throw new IllegalStateException("Cannot do backward pass: error not set");

        INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inputs[0].dataType(), forwardShape);
-       int start = from * step;
-       int end = (from + 1) * step;
+       long start = from * step;
+       long end = (from + 1) * step;

        switch (forwardShape.length) {
            case 2:

@ -154,8 +153,8 @@ public class UnstackVertex extends BaseGraphVertex {
        }

        //Mask arrays are either 1d (column vector) or 2d...
-       int start = from * minibatchSize;
-       int end = (from + 1) * minibatchSize;
+       long start = from * minibatchSize;
+       long end = (from + 1) * minibatchSize;
        INDArray outMask = maskArrays[0].get(NDArrayIndex.interval(start, end), NDArrayIndex.all());
        return new Pair<>(outMask, currentMaskState);
    }
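Note: a minimal sketch (not part of the patch) of the interval arithmetic the UnstackVertex hunks above now do in long: step = size(0) / stackSize, and [start, end) selects one sub-batch. Sizes below are assumptions for illustration.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class UnstackIntervalExample {
    public static void main(String[] args) {
        // A stacked minibatch of 3 sub-batches of 4 examples each
        INDArray stacked = Nd4j.rand(DataType.FLOAT, 12, 5);
        int stackSize = 3;
        long from = 1;                               // which sub-batch to pull out
        long step = stacked.size(0) / stackSize;     // 4, computed in long
        long start = from * step, end = (from + 1) * step;

        INDArray sub = stacked.get(NDArrayIndex.interval(start, end), NDArrayIndex.all());
        System.out.println(java.util.Arrays.toString(sub.shape())); // [4, 5]
    }
}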
@ -87,9 +87,8 @@ public class LastTimeStepVertex extends BaseGraphVertex {

        INDArray out;
        if (mask == null) {
-           // FIXME: int cast
            //No mask array -> extract same (last) column for all
-           int lastTS = (int) inputs[0].size(2) - 1;
+           long lastTS = inputs[0].size(2) - 1;
            out = inputs[0].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS));
            out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
            fwdPassTimeSteps = null; //Null -> last time step for all examples

@ -99,8 +98,7 @@ public class LastTimeStepVertex extends BaseGraphVertex {

            //Want the index of the last non-zero entry in the mask array.
            //Check a little here by using mulRowVector([0,1,2,3,...]) and argmax
-           // FIXME: int cast
-           int maxTsLength = (int) fwdPassShape[2];
+           long maxTsLength = fwdPassShape[2];
            INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, mask.dataType());
            INDArray temp = mask.mulRowVector(row);
            INDArray lastElementIdx = Nd4j.argMax(temp, 1);
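Note: a minimal sketch (not part of the patch) of the mulRowVector + argMax trick referenced in the hunk above for finding the last non-zero timestep per example; the mask values are an assumption for illustration.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LastNonZeroTimestep {
    public static void main(String[] args) {
        // Mask for 2 examples over 5 time steps; example 0 has 3 valid steps, example 1 has 5
        INDArray mask = Nd4j.create(new float[][]{
                {1, 1, 1, 0, 0},
                {1, 1, 1, 1, 1}});

        long maxTsLength = mask.size(1);
        INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, DataType.FLOAT); // [0,1,2,3,4]
        // Multiplying by the index row and taking argmax per row gives the last non-zero position
        INDArray lastIdx = Nd4j.argMax(mask.mulRowVector(row), 1);
        System.out.println(lastIdx); // [2, 4]
    }
}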
@ -346,7 +346,6 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con

    @Override
    public int getInputMiniBatchSize() {
-       // FIXME: int cast
        return (int) input.size(0);
    }

@ -229,7 +229,6 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
     */
    @Override
    public int numLabels() {
-       // FIXME: int cast
        return (int) labels.size(1);
    }

@ -236,7 +236,6 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
     */
    @Override
    public int numLabels() {
-       // FIXME: int cast
        return (int) labels.size(1);
    }

@ -86,19 +86,18 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
        INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
        delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d);

-       // FIXME: int cast
-       int n = (int)input.size(0);
-       int d, h, w, c;
+       long n = input.size(0);
+       long d, h, w, c;
        if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
-           d = (int)input.size(1);
-           h = (int)input.size(2);
-           w = (int)input.size(3);
-           c = (int)input.size(4);
+           d = input.size(1);
+           h = input.size(2);
+           w = input.size(3);
+           c = input.size(4);
        } else {
-           d = (int)input.size(2);
-           h = (int)input.size(3);
-           w = (int)input.size(4);
-           c = (int)input.size(1);
+           d = input.size(2);
+           h = input.size(3);
+           w = input.size(4);
+           c = input.size(1);
        }
        INDArray delta5d = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), delta2d, n, d, h, w, c, workspaceMgr, ArrayType.ACTIVATION_GRAD);

@ -130,7 +129,6 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.

    @Override
    public int numLabels() {
-       // FIXME: int cast
        return (int) labels.size(1);
    }

@ -180,10 +178,8 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
        INDArray input2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), in, workspaceMgr, ArrayType.ACTIVATIONS);
        INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training);

-       // FIXME: int cast
-       int n = (int)input.size(0);
-       int d, h, w, c;
+       long n = input.size(0);
+       long d, h, w, c;
        if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
            d = (int)input.size(1);
            h = (int)input.size(2);

@ -262,19 +258,18 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
        val newShape = input.shape().clone();
        newShape[1] = 1;

-       // FIXME
-       int n = (int)input.size(0);
-       int d, h, w, c;
+       long n = input.size(0);
+       long d, h, w, c;
        if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
-           d = (int)input.size(1);
-           h = (int)input.size(2);
-           w = (int)input.size(3);
-           c = (int)input.size(4);
+           d = input.size(1);
+           h = input.size(2);
+           w = input.size(3);
+           c = input.size(4);
        } else {
-           d = (int)input.size(2);
-           h = (int)input.size(3);
-           w = (int)input.size(4);
-           c = (int)input.size(1);
+           d = input.size(2);
+           h = input.size(3);
+           w = input.size(4);
+           c = input.size(1);
        }
        INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), scoreArray, n, d, h, w, c, workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray summedScores = scoreArrayTs.sum(1,2,3,4);
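Note: a minimal sketch (not part of the patch) of the dimension extraction the Cnn3DLossLayer hunks above repeat: sizes are read as long and their meaning depends on whether the layout is NDHWC or NCDHW. The example shape and data format are assumptions.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Cnn3DDims {
    public static void main(String[] args) {
        // NCDHW activations [minibatch, channels, depth, height, width]
        INDArray input = Nd4j.rand(DataType.FLOAT, 2, 3, 8, 16, 16);
        boolean ndhwc = false;

        long n = input.size(0);
        long d = ndhwc ? input.size(1) : input.size(2);
        long h = ndhwc ? input.size(2) : input.size(3);
        long w = ndhwc ? input.size(3) : input.size(4);
        long c = ndhwc ? input.size(4) : input.size(1);
        System.out.println(n + " " + d + " " + h + " " + w + " " + c); // 2 8 16 16 3
    }
}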
@ -88,8 +88,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
|
||||||
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
|
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
|
||||||
delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d);
|
delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d);
|
||||||
|
|
||||||
// FIXME: int cast
|
INDArray delta4d = ConvolutionUtils.reshape2dTo4d(delta2d, input.shape(), workspaceMgr, ArrayType.ACTIVATION_GRAD);
|
||||||
INDArray delta4d = ConvolutionUtils.reshape2dTo4d(delta2d, ArrayUtil.toInts(input.shape()), workspaceMgr, ArrayType.ACTIVATION_GRAD);
|
|
||||||
|
|
||||||
// grab the empty gradient
|
// grab the empty gradient
|
||||||
Gradient gradient = new DefaultGradient();
|
Gradient gradient = new DefaultGradient();
|
||||||
|
@ -119,7 +118,6 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int numLabels() {
|
public int numLabels() {
|
||||||
// FIXME: int cast
|
|
||||||
return (int) labels.size(1);
|
return (int) labels.size(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,8 +167,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
|
||||||
INDArray input2d = ConvolutionUtils.reshape4dTo2d(in, workspaceMgr, ArrayType.ACTIVATIONS);
|
INDArray input2d = ConvolutionUtils.reshape4dTo2d(in, workspaceMgr, ArrayType.ACTIVATIONS);
|
||||||
INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training);
|
INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training);
|
||||||
|
|
||||||
// FIXME: int cast
|
return ConvolutionUtils.reshape2dTo4d(out2d, input.shape(), workspaceMgr, ArrayType.ACTIVATIONS);
|
||||||
return ConvolutionUtils.reshape2dTo4d(out2d, ArrayUtil.toInts(input.shape()), workspaceMgr, ArrayType.ACTIVATIONS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -236,8 +233,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
|
||||||
val newShape = input.shape().clone();
|
val newShape = input.shape().clone();
|
||||||
newShape[1] = 1;
|
newShape[1] = 1;
|
||||||
|
|
||||||
// FIXME
|
INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo4d(scoreArray, newShape, workspaceMgr, ArrayType.FF_WORKING_MEM);
|
||||||
INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo4d(scoreArray, ArrayUtil.toInts(newShape), workspaceMgr, ArrayType.FF_WORKING_MEM);
|
|
||||||
INDArray summedScores = scoreArrayTs.sum(1,2,3).reshape(scoreArrayTs.size(0), 1);
|
INDArray summedScores = scoreArrayTs.sum(1,2,3).reshape(scoreArrayTs.size(0), 1);
|
||||||
|
|
||||||
if (fullNetRegTerm != 0.0) {
|
if (fullNetRegTerm != 0.0) {
|
||||||
|
|
|
@ -71,8 +71,7 @@ public class Convolution3DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW;
|
boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW;
|
||||||
|
|
||||||
// FIXME: int cast
|
long miniBatch = input.size(0);
|
||||||
int miniBatch = (int) input.size(0);
|
|
||||||
int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
|
int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
|
||||||
int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
|
int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
|
||||||
int inW = (int) (isNCDHW ? input.size(4) : input.size(3));
|
int inW = (int) (isNCDHW ? input.size(4) : input.size(3));
|
||||||
|
@ -189,8 +188,7 @@ public class Convolution3DLayer extends ConvolutionLayer {
|
||||||
+ " " + layerId());
|
+ " " + layerId());
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: int cast
|
long miniBatch = input.size(0);
|
||||||
int miniBatch = (int) input.size(0);
|
|
||||||
int inputChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
|
int inputChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
|
||||||
int inD =(int) (isNCDHW ? input.size(2) : input.size(1));
|
int inD =(int) (isNCDHW ? input.size(2) : input.size(1));
|
||||||
int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
|
int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.convolution.Convolution;
|
import org.nd4j.linalg.convolution.Convolution;
|
||||||
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
@ -113,13 +114,12 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
if(epsilon.dataType() != dataType)
|
if(epsilon.dataType() != dataType)
|
||||||
epsilon = epsilon.castTo(dataType);
|
epsilon = epsilon.castTo(dataType);
|
||||||
|
|
||||||
// FIXME: int cast
|
long miniBatch = input.size(0);
|
||||||
int miniBatch = (int) input.size(0);
|
|
||||||
int inH = (int) input.size(2);
|
int inH = (int) input.size(2);
|
||||||
int inW = (int) input.size(3);
|
int inW = (int) input.size(3);
|
||||||
|
|
||||||
int outDepth = (int) weights.size(0);
|
long outDepth = weights.size(0);
|
||||||
int inDepth = (int) weights.size(1);
|
long inDepth = weights.size(1);
|
||||||
int kH = (int) weights.size(2);
|
int kH = (int) weights.size(2);
|
||||||
int kW = (int) weights.size(3);
|
int kW = (int) weights.size(3);
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
INDArray biasGradView = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
|
INDArray biasGradView = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
|
||||||
INDArray weightGradView = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); //4d, c order. Shape: [outDepth,inDepth,kH,kW]
|
INDArray weightGradView = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); //4d, c order. Shape: [outDepth,inDepth,kH,kW]
|
||||||
INDArray weightGradView2df = Shape
|
INDArray weightGradView2df = Shape
|
||||||
.newShapeNoCopy(weightGradView, new int[] {outDepth, inDepth * kH * kW}, false).transpose();
|
.newShapeNoCopy(weightGradView, new long[]{outDepth, inDepth * kH * kW}, false).transpose();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
//Note: due to the permute in preOut, and the fact that we essentially do a preOut.muli(epsilon), this reshape
|
//Note: due to the permute in preOut, and the fact that we essentially do a preOut.muli(epsilon), this reshape
|
||||||
// should be zero-copy; only possible exception being sometimes with the "identity" activation case
|
// should be zero-copy; only possible exception being sometimes with the "identity" activation case
|
||||||
INDArray delta2d = delta.reshape('c', new int[] {outDepth, miniBatch * outH * outW}); //Shape.newShapeNoCopy(delta,new int[]{outDepth,miniBatch*outH*outW},false);
|
INDArray delta2d = delta.reshape('c', new long[] {outDepth, miniBatch * outH * outW}); //Shape.newShapeNoCopy(delta,new int[]{outDepth,miniBatch*outH*outW},false);
|
||||||
|
|
||||||
//Do im2col, but with order [miniB,outH,outW,depthIn,kH,kW]; but need to input [miniBatch,channels,kH,kW,outH,outW] given the current im2col implementation
|
//Do im2col, but with order [miniB,outH,outW,depthIn,kH,kW]; but need to input [miniBatch,channels,kH,kW,outH,outW] given the current im2col implementation
|
||||||
//To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that
|
//To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that
|
||||||
|
@ -231,7 +231,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
//Calculate epsilons for layer below, in 2d format (note: this is in 'image patch' format before col2im reduction)
|
//Calculate epsilons for layer below, in 2d format (note: this is in 'image patch' format before col2im reduction)
|
||||||
//Note: cc -> f mmul here, then reshape to 6d in f order
|
//Note: cc -> f mmul here, then reshape to 6d in f order
|
||||||
INDArray epsNext2d = w2d.mmul(delta2d); //TODO can we reuse im2col array instead of allocating new result array?
|
INDArray epsNext2d = w2d.mmul(delta2d); //TODO can we reuse im2col array instead of allocating new result array?
|
||||||
INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new int[] {kW, kH, inDepth, outW, outH, miniBatch}, true);
|
INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new long[] {kW, kH, inDepth, outW, outH, miniBatch}, true);
|
||||||
|
|
||||||
//Calculate epsilonNext by doing im2col reduction.
|
//Calculate epsilonNext by doing im2col reduction.
|
||||||
//Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW]
|
//Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW]
|
||||||
|
@ -282,7 +282,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void validateInputDepth(int inDepth) {
|
protected void validateInputDepth(long inDepth) {
|
||||||
if (input.size(1) != inDepth) {
|
if (input.size(1) != inDepth) {
|
||||||
String layerName = conf.getLayer().getLayerName();
|
String layerName = conf.getLayer().getLayerName();
|
||||||
if (layerName == null)
|
if (layerName == null)
|
||||||
|
@ -313,14 +313,13 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType);
|
INDArray input = this.input.castTo(dataType);
|
||||||
|
|
||||||
// FIXME: int cast
|
long miniBatch = input.size(0);
|
||||||
int miniBatch = (int) input.size(0);
|
long outDepth = weights.size(0);
|
||||||
int outDepth = (int) weights.size(0);
|
long inDepth = weights.size(1);
|
||||||
int inDepth = (int) weights.size(1);
|
|
||||||
validateInputDepth(inDepth);
|
validateInputDepth(inDepth);
|
||||||
|
|
||||||
int kH = (int) weights.size(2);
|
long kH = weights.size(2);
|
||||||
int kW = (int) weights.size(3);
|
long kW = weights.size(3);
|
||||||
|
|
||||||
int[] dilation = layerConf().getDilation();
|
int[] dilation = layerConf().getDilation();
|
||||||
int[] kernel = layerConf().getKernelSize();
|
int[] kernel = layerConf().getKernelSize();
|
||||||
|
@ -331,7 +330,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
if (convolutionMode == ConvolutionMode.Same) {
|
if (convolutionMode == ConvolutionMode.Same) {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
|
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
|
||||||
|
|
||||||
// FIXME: int cast
|
if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE)
|
||||||
|
throw new ND4JArraySizeException();
|
||||||
pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
|
pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
|
||||||
strides, dilation );
|
strides, dilation );
|
||||||
} else {
|
} else {
|
||||||
|
@ -397,10 +397,12 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
|
INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
|
||||||
INDArray col2 = col.permute(0, 3, 4, 5, 1, 2);
|
INDArray col2 = col.permute(0, 3, 4, 5, 1, 2);
|
||||||
INDArray im2ColIn = input.castTo(col2.dataType()); //No op if already (for example) float
|
INDArray im2ColIn = input.castTo(col2.dataType()); //No op if already (for example) float
|
||||||
Convolution.im2col(im2ColIn, kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1],
|
if (kH > Integer.MAX_VALUE || kW > Integer.MAX_VALUE)
|
||||||
|
throw new ND4JArraySizeException();
|
||||||
|
Convolution.im2col(im2ColIn, (int)kH, (int)kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1],
|
||||||
convolutionMode == ConvolutionMode.Same, col2);
|
convolutionMode == ConvolutionMode.Same, col2);
|
||||||
|
|
||||||
INDArray im2col2d = Shape.newShapeNoCopy(col, new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false);
|
INDArray im2col2d = Shape.newShapeNoCopy(col, new long[] {miniBatch * outH * outW, inDepth * kH * kW}, false);
|
||||||
|
|
||||||
//Current order of weights: [depthOut,depthIn,kH,kW], c order
|
//Current order of weights: [depthOut,depthIn,kH,kW], c order
|
||||||
//Permute to give [kW,kH,depthIn,depthOut], f order
|
//Permute to give [kW,kH,depthIn,depthOut], f order
|
||||||
|
@ -418,7 +420,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
|
|
||||||
//Now, reshape to [outW,outH,miniBatch,outDepth], and permute to have correct output order: [miniBath,outDepth,outH,outW];
|
//Now, reshape to [outW,outH,miniBatch,outDepth], and permute to have correct output order: [miniBath,outDepth,outH,outW];
|
||||||
z = Shape.newShapeNoCopy(z, new int[] {outW, outH, miniBatch, outDepth}, true);
|
z = Shape.newShapeNoCopy(z, new long[] {outW, outH, miniBatch, outDepth}, true);
|
||||||
z = z.permute(2, 3, 1, 0);
|
z = z.permute(2, 3, 1, 0);
|
||||||
|
|
||||||
if (training && cacheMode != CacheMode.NONE && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE) && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE)) {
|
if (training && cacheMode != CacheMode.NONE && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE) && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE)) {
|
||||||
|
|
|
@@ -171,13 +171,14 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
                             + " " + layerId());
         }

-        // FIXME: int cast
-        int inDepth = (int) weights.size(0);
-        int outDepth = (int) weights.size(1);
+        long inDepth = weights.size(0);
+        long outDepth = weights.size(1);

         if (input.size(1) != inDepth && input.size(3) == inDepth) {
+            //TODO AB 2019/10/25 this is an ugly "pseudo-NHWC support" hack that needs to be removed ASAD
+            //https://github.com/eclipse/deeplearning4j/issues/8315
             input = input.permute(0, 3, 1, 2);
-        } else if (input.size(1) != inDepth && input.size(3) != inDepth) {
+        } else if (input.size(1) != inDepth ) {
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";

@@ -197,7 +198,6 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
         int[] pad;
         int[] outSize;
         if (convolutionMode == ConvolutionMode.Same) {
-            // FIXME: int cast
             outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
             pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
                     strides, dilation );

@@ -206,8 +206,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
             outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
         }

-        int outH = outSize[0];
-        int outW = outSize[1];
+        long outH = outSize[0];
+        long outW = outSize[1];


         val miniBatch = input.size(0);
@@ -32,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;

@@ -75,12 +76,11 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {

         INDArray input = this.input.castTo(dataType); //No-op if correct type

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        int inH = (int)input.size(2);
+        int inW = (int)input.size(3);

-        int inDepth = (int) depthWiseWeights.size(2);
+        long inDepth = depthWiseWeights.size(2);
         int kH = (int) depthWiseWeights.size(0);
         int kW = (int) depthWiseWeights.size(1);

@@ -169,10 +169,9 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {

         INDArray input = this.input.castTo(dataType); //no-op if correct dtype

-        // FIXME: int cast
-        int inDepth = (int) depthWiseWeights.size(2);
-        int depthMultiplier = (int) depthWiseWeights.size(3);
-        int outDepth = depthMultiplier * inDepth;
+        long inDepth = depthWiseWeights.size(2);
+        long depthMultiplier = depthWiseWeights.size(3);
+        long outDepth = depthMultiplier * inDepth;

         if (input.size(1) != inDepth) {
             String layerName = conf.getLayer().getLayerName();

@@ -197,7 +196,9 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
         if (convolutionMode == ConvolutionMode.Same) {
             outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation);

-            // FIXME: int cast
+            if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) {
+                throw new ND4JArraySizeException();
+            }
             pad = ConvolutionUtils.getSameModeTopLeftPadding(
                     outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation);
         } else {

@@ -205,8 +206,8 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
             outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation);
         }

-        int outH = outSize[0];
-        int outW = outSize[1];
+        long outH = outSize[0];
+        long outW = outSize[1];

         val miniBatch = input.size(0);
         INDArray output = workspaceMgr.create(
@@ -33,6 +33,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;

@@ -90,10 +91,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {

         INDArray input = this.input.castTo(dataType);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        int inH = (int)input.size(2);
+        int inW = (int)input.size(3);

         int inDepth = (int) depthWiseWeights.size(1);
         int kH = (int) depthWiseWeights.size(2);

@@ -194,9 +194,8 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
                             + " " + layerId());
         }

-        // FIXME: int cast
-        int inDepth = (int) depthWiseWeights.size(1);
-        int outDepth = (int) pointWiseWeights.size(0);
+        long inDepth = depthWiseWeights.size(1);
+        long outDepth = pointWiseWeights.size(0);

         if (input.size(1) != inDepth) {
             String layerName = conf.getLayer().getLayerName();

@@ -220,7 +219,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
         if (convolutionMode == ConvolutionMode.Same) {
             outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation

-            // FIXME: int cast
+            if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) {
+                throw new ND4JArraySizeException();
+            }
             pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
                     strides, dilation );
         } else {
@@ -75,11 +75,10 @@ public class SpaceToDepth extends AbstractLayer<org.deeplearning4j.nn.conf.layer
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         assertInputSet(true);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        long inDepth = input.size(1);
+        long inH = input.size(2);
+        long inW = input.size(3);

         INDArray input = this.input.castTo(dataType); //No-op if already correct type

@@ -122,17 +121,16 @@ public class SpaceToDepth extends AbstractLayer<org.deeplearning4j.nn.conf.layer
             return preOutput;
         }

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int depth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        long depth = input.size(1);
+        long inH = input.size(2);
+        long inW = input.size(3);

         int blockSize = getBlockSize();

-        int outH = inH / blockSize;
-        int outW = inW / blockSize;
-        int outDepth = depth * blockSize * blockSize;
+        long outH = inH / blockSize;
+        long outW = inW / blockSize;
+        long outDepth = depth * blockSize * blockSize;

         INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{1, miniBatch * outDepth * outH * outW}, 'c');
         INDArray reshapedOut;
@@ -71,9 +71,8 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf

         boolean isNCDHW = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW;

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
+        long miniBatch = input.size(0);
+        long inChannels = isNCDHW ? input.size(1) : input.size(4);
         int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
         int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
         int inW = (int) (isNCDHW ? input.size(4) : input.size(3));

@@ -148,9 +147,8 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
             }
         }

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
+        long miniBatch = input.size(0);
+        long inChannels = isNCDHW ? input.size(1) : input.size(4);
         int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
         int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
         int inW = (int) (isNCDHW ? input.size(4) : input.size(3));

@@ -170,9 +168,9 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
             outSize = Convolution3DUtils.get3DOutputSize(
                     input, kernel, strides, pad, convolutionMode, dilation, isNCDHW);
         }
-        int outD = outSize[0];
-        int outH = outSize[1];
-        int outW = outSize[2];
+        long outD = outSize[0];
+        long outH = outSize[1];
+        long outW = outSize[2];

         String opName = layerConf().getPoolingType() == PoolingType.MAX ? "maxpool3dnew" : "avgpool3dnew";
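The 3D subsampling hunks above read each dimension according to the configured data format. A small illustrative sketch of that index mapping, assuming a hypothetical 5d activations array (class and variable names here are illustrative only):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    class DataFormatSketch {
        public static void main(String[] args) {
            boolean isNCDHW = true; // channels-first: [minibatch, channels, depth, height, width]
            INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 8, 8, 8);

            // size(i) returns long, so the minibatch/channel counts can stay as long
            long miniBatch  = in.size(0);
            long inChannels = isNCDHW ? in.size(1) : in.size(4);
            long inD = isNCDHW ? in.size(2) : in.size(1);
            long inH = isNCDHW ? in.size(3) : in.size(2);
            long inW = isNCDHW ? in.size(4) : in.size(3);

            System.out.println(miniBatch + " " + inChannels + " " + inD + "x" + inH + "x" + inW);
        }
    }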
@@ -108,11 +108,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
         if(epsilon.dataType() != dataType)
             epsilon = epsilon.castTo(dataType);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        int inH = (int)input.size(2);
+        int inW = (int)input.size(3);

         int[] kernel = layerConf().getKernelSize();
         int[] strides = layerConf().getStride();

@@ -158,9 +155,6 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l

         //subsampling doesn't have weights and thus gradients are not calculated for this layer
         //only scale and reshape epsilon
-        // FIXME: int cast
-        int inputHeight = (int) input().size(-2);
-        int inputWidth = (int) input().size(-1);
         Gradient retGradient = new DefaultGradient();


@@ -231,11 +225,10 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l

         INDArray input = this.input.castTo(dataType);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        long inDepth = input.size(1);
+        int inH = (int)input.size(2);
+        int inW = (int)input.size(3);

         int[] kernel = layerConf().getKernelSize();
         int[] strides = layerConf().getStride();

@@ -250,8 +243,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
             pad = layerConf().getPadding();
             outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
         }
-        int outH = outSize[0];
-        int outW = outSize[1];
+        long outH = outSize[0];
+        long outW = outSize[1];


         if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) {

@@ -278,9 +271,6 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
             }
         }

-
-
-
         INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c');
         DynamicCustomOp.DynamicCustomOpsBuilder b;
         int extra = 0;
@@ -65,11 +65,10 @@ public class Upsampling1D extends Upsampling2D {
         INDArray originalInput = input;
         input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = input.size(0);
+        long inDepth = input.size(1);
+        long inH = input.size(2);
+        long inW = input.size(3);


         INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), miniBatch * inDepth * inH * inW);
@@ -62,11 +62,10 @@ public class Upsampling2D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         assertInputSet(true);

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = (int) input.size(0);
+        long inDepth = (int) input.size(1);
+        long inH = (int) input.size(2);
+        long inW = (int) input.size(3);

         INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c');


@@ -106,15 +105,14 @@ public class Upsampling2D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
             return preOutput;
         }

-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inDepth = (int) input.size(1);
-        int inH = (int) input.size(2);
-        int inW = (int) input.size(3);
+        long miniBatch = (int) input.size(0);
+        long inDepth = (int) input.size(1);
+        long inH = (int) input.size(2);
+        long inW = (int) input.size(3);

         int[] size = getSize();
-        int outH = inH * size[0];
-        int outW = inW * size[1];
+        int outH = (int)inH * size[0];
+        int outW = (int)inW * size[1];

         INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c');
@@ -68,22 +68,21 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
         assertInputSet(true);

         boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;
-        // FIXME: int cast
         // Assumes NCDHW order
-        int miniBatch = (int) input.size(0);
-        int inChannels, inD, inH, inW;
+        long miniBatch = input.size(0);
+        long inChannels, inD, inH, inW;
         int[] intArgs;
         if(ncdhw){
-            inChannels = (int) input.size(1);
-            inD = (int) input.size(2);
-            inH = (int) input.size(3);
-            inW = (int) input.size(4);
+            inChannels = input.size(1);
+            inD = input.size(2);
+            inH = input.size(3);
+            inW = input.size(4);
             intArgs = new int[] {1}; // 1 is channels first
         } else {
-            inD = (int) input.size(1);
-            inH = (int) input.size(2);
-            inW = (int) input.size(3);
-            inChannels = (int) input.size(4);
+            inD = input.size(1);
+            inH = input.size(2);
+            inW = input.size(3);
+            inChannels = input.size(4);
             intArgs = new int[] {0}; // 0 is channels last
         }

@@ -134,9 +133,8 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
         }

         boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;
-        // FIXME: int cast
-        int miniBatch = (int) input.size(0);
-        int inChannels, inD, inH, inW;
+        long miniBatch = input.size(0);
+        long inChannels, inD, inH, inW;
         int[] intArgs;
         int[] size = getSize();
         if(ncdhw){
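A common thread in the upsampling and convolution hunks is that reshape targets move from new int[]{...} to new long[]{...}: INDArray.size(...) returns long, and products such as minibatch * channels * height * width can exceed Integer.MAX_VALUE. A minimal sketch of the long-based reshape, assuming a hypothetical activations array:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    class LongShapeSketch {
        public static void main(String[] args) {
            INDArray act = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 5);

            // Dimension products computed in long arithmetic and passed to reshape as longs,
            // mirroring the int[] -> long[] changes above.
            long mb = act.size(0), c = act.size(1), h = act.size(2), w = act.size(3);
            INDArray flat = act.reshape('c', mb, c * h * w);
            System.out.println(java.util.Arrays.toString(flat.shape()));
        }
    }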
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;

@@ -64,8 +65,7 @@ public class EmbeddingLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
         INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
         weightGradients.assign(0);

-        // FIXME: int cast
-        int[] indexes = new int[(int) input.length()];
+        long[] indexes = new long[(int) input.length()];
         for (int i = 0; i < indexes.length; i++) {
             indexes[i] = input.getInt(i, 0);
         }

@@ -99,7 +99,8 @@ public class EmbeddingLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.

         val nIn = layerConf().getNIn();

-        // FIXME: int cast
+        if (input.length() > Integer.MAX_VALUE)
+            throw new ND4JArraySizeException();
         int[] indexes = new int[(int) input.length()];
         for (int i = 0; i < indexes.length; i++) {
             indexes[i] = input.getInt(i, 0);
@@ -16,21 +16,28 @@

 package org.deeplearning4j.nn.layers.mkldnn;

+import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
+import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
 import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.OpContext;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;

+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;

 /**

@@ -56,33 +63,59 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
     }

     @Override
-    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
-                                                     INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
-        //2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166
-        //Also no MKL-DNN implemented for backprop anyway
+    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
+                                                     INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
+        if(input.dataType() != DataType.FLOAT)
+            return null; //MKL-DNN only supports float
         /*
-        INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon};
-
-        INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape());
-
-        INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, }
-        BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder()
-                .applyBeta(gamma != null)
-                .applyGamma(gamma != null)
-                .axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases
-                .epsilon(eps)
-                .inputArrays(in)
-                .outputArrays(new INDArray[]{out})
+        //TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335
+        List<INDArray> args = new ArrayList<>();
+        args.add(input);
+        args.add(meanCache);
+        args.add(varCache);
+        args.add(epsilon);
+        if(gamma != null)
+            args.add(gamma.reshape(gamma.length()));
+        if(beta != null)
+            args.add(beta.reshape(beta.length()));
+
+        DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
+                .addInputs(args.toArray(new INDArray[0]))
+                .addIntegerArguments(
+                        gamma == null ? 0 : 1,          //Apply scale
+                        beta == null ? 0 : 1,           //Apply beta
+                        1)                              //Axis (NCHW)
+                .addFloatingPointArguments(eps)
                 .build();
-        Nd4j.exec(bn);
-        */

+        INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
+        INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
+        INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());

+        op.setOutputArgument(0, epsAtInput);
+        op.setOutputArgument(1, dLdm);
+        op.setOutputArgument(2, dLdv);
+        if(dGammaView != null) {
+            //Both are always null/not null simultaneously
+            op.setOutputArgument(3, dGammaView.reshape(dGammaView.length()));
+            op.setOutputArgument(4, dBetaView.reshape(dBetaView.length()));
+        }

+        Nd4j.exec(op);

+        Gradient g = new DefaultGradient();
+        g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
+        g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);

+        return new Pair<>(g, epsAtInput);
+        */
         return null;
     }

     @Override
-    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var,
+    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var,
                               double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
         if(x.dataType() != DataType.FLOAT)
             return null; //MKL-DNN only supports float
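One contract visible in this hunk: the MKL-DNN batch norm helper now returns null as soon as the input is not FLOAT, and the calling layer treats a null (or failed) helper result as a signal to use the built-in implementation. A small sketch of that convention, with purely hypothetical helper/layer methods:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class HelperFallbackSketch {
        // Hypothetical accelerated path: null means "not supported here".
        static INDArray helperForward(INDArray x) {
            if (x.dataType() != DataType.FLOAT)
                return null;    // MKL-DNN path only supports float
            return x;           // stand-in for the accelerated computation
        }

        static INDArray forward(INDArray x) {
            INDArray out = helperForward(x);
            if (out != null)
                return out;     // accelerated path succeeded
            return x.dup();     // stand-in for the built-in (fallback) implementation
        }
    }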
@@ -0,0 +1,168 @@
+package org.deeplearning4j.nn.layers.mkldnn;
+
+import org.deeplearning4j.nn.api.Layer;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.LSTM;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
+import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.*;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.BooleanIndexing;
+import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+public class MKLDNNLSTMHelper implements LSTMHelper {
+    @Override
+    public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
+        //TODO check other activation functions for MKLDNN
+        return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled();
+    }
+
+    @Override
+    public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
+                                                     INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT,
+                                                     int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey,
+                                                     String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews,
+                                                     INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
+        //Not yet implemented/supported
+        return null;
+    }
+
+    @Override
+    public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
+                                  INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training,
+                                  INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards,
+                                  String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
+
+        /*
+        DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward)
+        Inputs:
+        x = [bS, nIn, sL]
+        Wx = [nIn, 4*nOut]
+        Wr = [nOut, 4*nOut]
+        Wp = [3*nOut]               Optional peephole weights
+        b = [4*nOut]
+        seqLen = [bS]
+        initialOut = [bs, nOut]
+        initialCell = [bs, nOut]
+
+        Outputs:
+        out = [bS, nOut, sL]
+        outLast = [bs, nOut]
+        cellLast = [bs,nOut]
+
+        Gates order: input, forget, input modulation, output
+
+
+        const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
+        const auto hasSeqLen  = B_ARG(1);   // indicates whether seqLen array is provided
+        const auto hasInitH   = B_ARG(2);   // indicates whether initial output is provided
+        const auto hasInitC   = B_ARG(3);   // indicates whether initial cell state is provided
+        const auto hasPH      = B_ARG(4);   // indicates whether peephole connections are present
+        const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
+        const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+        const auto retLastC   = B_ARG(7);   // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+        */
+
+        INDArray b1d = biases.reshape(biases.length());
+        INDArray seqLen = null;
+        if(maskArray != null){
+            seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1);    //First 0 along dimension 1 (for [mb, seqLen])
+        }
+
+        List<INDArray> args = new ArrayList<>();
+        args.add(input);
+        args.add(inputWeights);
+        args.add(recurrentWeights);
+        if(hasPeepholeConnections){
+            throw new IllegalStateException("Not yet implemented");
+        }
+        args.add(b1d);
+        if(seqLen != null)
+            args.add(seqLen);
+        if(prevOutputActivations != null)
+            args.add(prevOutputActivations);
+        if(prevMemCellState != null)
+            args.add(prevMemCellState);
+
+        IActivation a = ((LSTM)conf.getLayer()).getActivationFn();
+
+        DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
+                .addInputs(args.toArray(new INDArray[0]))
+                .addBooleanArguments(
+                        true,                           //hasBiases
+                        seqLen != null,                 //hasSeqLen
+                        prevOutputActivations != null,  //hasInitH
+                        prevMemCellState != null,       //hasInitC
+                        hasPeepholeConnections,         //hasPh
+                        true,                           //retFullSeq
+                        true,                           //retLastH
+                        true                            //retLastC
+                )
+                .addIntegerArguments(
+                        2,                                  //data format: 2 = [bS, nIn, sL]
+                        0,                                  //direction: 0 = forward
+                        activationToArg(gateActivationFn),  //Gate activation
+                        activationToArg(a),                 //Cell state activation
+                        activationToArg(a)                  //Output activation (same as cell in DL4J)
+                )
+                .build();
+
+        List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
+
+        for(LongShapeDescriptor lsd : outShapes){
+            INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
+            op.addOutputArgument(arr);
+        }
+
+        FwdPassReturn f = new FwdPassReturn();
+        f.fwdPassOutput = op.getOutputArgument(0);
+        f.lastAct = op.getOutputArgument(1);
+        f.lastMemCell = op.getOutputArgument(2);
+
+        return f;
+    }
+
+    @Override
+    public Map<String, Long> helperMemoryUse() {
+        return Collections.emptyMap();
+    }
+
+    private int activationToArg(IActivation a){
+        //0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+        if(a instanceof ActivationTanH)
+            return 0;
+        if(a instanceof ActivationReLU)
+            return 1;
+        if(a instanceof ActivationSigmoid)
+            return 2;
+        if(a instanceof ActivationIdentity)
+            return 3;
+        if(a instanceof ActivationLReLU)
+            return 4;
+        if(a instanceof ActivationThresholdedReLU)
+            return 5;
+        if(a instanceof ActivationHardSigmoid)
+            return 7;
+        if(a instanceof ActivationELU)
+            return 8;
+        if(a instanceof ActivationSoftSign)
+            return 9;
+        if(a instanceof ActivationSoftPlus)
+            return 10;
+        throw new IllegalStateException("Unknown or not supported activation function: " + a);
+    }
+}
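The new helper only claims support for the sigmoid-gate / tanh-cell configuration (and no peepholes) when MKL-DNN is enabled; a short usage sketch of that gating, mirroring the (currently disabled) wiring in the LSTM layer hunk further below:

    import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
    import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
    import org.nd4j.linalg.activations.impl.ActivationSigmoid;
    import org.nd4j.linalg.activations.impl.ActivationTanH;

    class LstmHelperSelectionSketch {
        public static void main(String[] args) {
            LSTMHelper helper = new MKLDNNLSTMHelper();
            // Anything other than sigmoid gates + tanh cell (or peepholes, or MKL-DNN disabled)
            // falls back to the built-in DL4J LSTM implementation.
            if (!helper.checkSupported(new ActivationSigmoid(), new ActivationTanH(), false)) {
                helper = null;
            }
            System.out.println(helper != null ? "MKL-DNN LSTM helper active" : "built-in LSTM path");
        }
    }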
@@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
         INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config
         INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
         INDArray gamma = null;
+        INDArray beta = null;
         INDArray dGammaView;
         INDArray dBetaView;
         INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);

@@ -129,6 +130,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
             dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
         } else {
             gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
+            beta = getParam(BatchNormalizationParamInitializer.BETA);
             dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
             dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
         }

@@ -152,15 +154,14 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
             eps = epsilon;
         }

-        // FIXME: int cast
         Pair<Gradient,INDArray> ret = null;
         try {
-            ret = helper.backpropGradient(in, eps, ArrayUtil.toInts(shape), gamma, dGammaView, dBetaView,
+            ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView,
                     layerConf.getEps(), workspaceMgr);
         } catch (ND4JOpProfilerException e){
             throw e; //NaN panic etc for debugging
         } catch (Throwable t){
-            if(t.getMessage().contains("Failed to allocate")){
+            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                 //This is a memory exception - don't fallback to built-in implementation
                 throw t;
             }

@@ -438,7 +439,6 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
         //Note that cudnn does not support dense (2d) batch norm case as of v7.1
         double decay = layerConf.getDecay();

-        // FIXME: int cast
         INDArray ret = null;
         try {
             if(globalVarView == null){

@@ -448,12 +448,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
                 globalVarView.muli(globalVarView);
             }

-            ret = helper.preOutput(in, training == TrainingMode.TRAIN, ArrayUtil.toInts(shape), gamma, beta, globalMeanView,
+            ret = helper.preOutput(in, training == TrainingMode.TRAIN, shape, gamma, beta, globalMeanView,
                     globalVarView, decay, layerConf.getEps(), workspaceMgr);
         } catch (ND4JOpProfilerException e){
             throw e; //NaN panic etc for debugging
         } catch (Throwable t) {
-            if(t.getMessage().contains("Failed to allocate")){
+            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                 //This is a memory exception - don't fallback to built-in implementation
                 throw t;
             }
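These hunks also harden the helper try/catch: Throwable.getMessage() can be null, so it is checked before matching on "Failed to allocate", and only genuine allocation failures are rethrown while everything else falls back to the built-in path. A minimal standalone sketch of that guard (method names are hypothetical):

    class HelperExceptionGuardSketch {
        static Object tryHelper() {
            throw new RuntimeException();   // stand-in for a helper failure with no message
        }

        static Object callWithFallback() {
            try {
                return tryHelper();
            } catch (Throwable t) {
                // getMessage() may be null, so guard before calling contains(...)
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    throw t;                        // memory problem: do not fall back
                }
                return "built-in implementation";   // otherwise fall back
            }
        }
    }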
@@ -31,10 +31,10 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 public interface BatchNormalizationHelper extends LayerHelper {
     boolean checkSupported(double eps, boolean fixedGammaBeta);

-    Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
+    Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                                               INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);

-    INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean,
+    INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
                        INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);

     INDArray getMeanCache(DataType dataType);
@@ -144,7 +144,7 @@ public class LocalResponseNormalization
         } catch (ND4JOpProfilerException e){
             throw e; //NaN panic etc for debugging
         } catch (Throwable t){
-            if(t.getMessage().contains("Failed to allocate")){
+            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                 //This is a memory exception - don't fallback to built-in implementation
                 throw t;
             }

@@ -211,7 +211,7 @@ public class LocalResponseNormalization
         } catch (ND4JOpProfilerException e){
             throw e; //NaN panic etc for debugging
         } catch (Throwable t){
-            if(t.getMessage().contains("Failed to allocate")){
+            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                 //This is a memory exception - don't fallback to built-in implementation
                 throw t;
             }
@@ -114,10 +114,9 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
         double lambdaCoord = layerConf().getLambdaCoord();
         double lambdaNoObj = layerConf().getLambdaNoObj();

-        // FIXME: int cast
-        int mb = (int) input.size(0);
-        int h = (int) input.size(2);
-        int w = (int) input.size(3);
+        long mb = input.size(0);
+        long h = input.size(2);
+        long w = input.size(3);
         int b = (int) layerConf().getBoundingBoxes().size(0);
         int c = (int) labels.size(1)-4;

@@ -243,12 +242,12 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l

         //Class prediction loss
         INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c]
-                .dup('c').reshape('c', new int[]{mb*b*h*w, c});
+                .dup('c').reshape('c', new long[]{mb*b*h*w, c});
         INDArray classLabelsBroadcast = Nd4j.createUninitialized(input.dataType(), new long[]{mb, b, c, h, w}, 'c');
         for(int i=0; i<b; i++ ){
             classLabelsBroadcast.get(all(), point(i), all(), all(), all()).assign(classLabels); //[mb, c, h, w] to [mb, b, c, h, w]
         }
-        INDArray classLabels2d = classLabelsBroadcast.permute(0,1,3,4,2).dup('c').reshape('c', new int[]{mb*b*h*w, c});
+        INDArray classLabels2d = classLabelsBroadcast.permute(0,1,3,4,2).dup('c').reshape('c', new long[]{mb*b*h*w, c});

         //Calculate the loss:
         ILossFunction lossConfidence = new LossL2();

@@ -297,7 +296,7 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
         // ----- Gradient Calculation (specifically: return dL/dIn -----

         INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
-        INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new int[]{mb, b, 5+c, h, w}, false);
+        INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new long[]{mb, b, 5+c, h, w}, false);
         INDArray epsClassPredictions = epsOut5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [mb, b, 5+c, h, w]
         INDArray epsXY = epsOut5.get(all(), all(), interval(0,2), all(), all());
         INDArray epsWH = epsOut5.get(all(), all(), interval(2,4), all(), all());

@@ -426,16 +425,16 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
      * @return IOU and gradients
      */
    private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask, INDArray objectPresentMaskBool){
-        // FIXME: int cast
-        int mb = (int) labelTL.size(0);
-        int h = (int) labelTL.size(2);
-        int w = (int) labelTL.size(3);
-        int b = (int) predictedWH.size(1);
+        long mb = labelTL.size(0);
+        long h = labelTL.size(2);
+        long w = labelTL.size(3);
+        long b = predictedWH.size(1);

         INDArray labelWH = labelBR.sub(labelTL);    //4d [mb, 2, H, W], label W/H in terms of number of grid boxes

-        int gridH = (int) labelTL.size(2);
-        int gridW = (int) labelTL.size(3);
+        long gridH = labelTL.size(2);
+        long gridW = labelTL.size(3);
         //Add grid positions to the predicted XY values (to get predicted XY in terms of grid cell units in image,
         // from (0 to 1 in grid cell) format)
         INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, predictedWH.dataType());
@@ -45,12 +45,11 @@ public class YoloUtils {
     }

     public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){
-        // FIXME: int cast
-        int mb = (int) input.size(0);
-        int h = (int) input.size(2);
-        int w = (int) input.size(3);
-        int b = (int) boundingBoxPriors.size(0);
-        int c = (int) (input.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
+        long mb = input.size(0);
+        long h = input.size(2);
+        long w = input.size(3);
+        long b = boundingBoxPriors.size(0);
+        long c = input.size(1)/b-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5

         INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c');
         INDArray output5 = output.reshape('c', mb, b, 5+c, h, w);

@@ -77,7 +76,7 @@ public class YoloUtils {
         //TODO OPTIMIZE?
         INDArray inputClassesPreSoftmax = input5.get(all(), all(), interval(5, 5+c), all(), all());   //Shape: [minibatch, C, H, W]
         INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c]
-                .dup('c').reshape('c', new int[]{mb*b*h*w, c});
+                .dup('c').reshape('c', new long[]{mb*b*h*w, c});
         Transforms.softmax(classPredictionsPreSoftmax2d, false);
         INDArray postSoftmax5d = classPredictionsPreSoftmax2d.reshape('c', mb, b, h, w, c ).permute(0, 1, 4, 2, 3);

@@ -173,13 +172,12 @@ public class YoloUtils {
             throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold);
         }

-        // FIXME: int cast
         //Activations format: [mb, 5b+c, h, w]
-        int mb = (int) networkOutput.size(0);
-        int h = (int) networkOutput.size(2);
-        int w = (int) networkOutput.size(3);
-        int b = (int) boundingBoxPriors.size(0);
-        int c = (int) (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
+        long mb = networkOutput.size(0);
+        long h = networkOutput.size(2);
+        long w = networkOutput.size(3);
+        long b = boundingBoxPriors.size(0);
+        long c = (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5

         //Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W] to [minibatch, B, 5, H, W]
         INDArray output5 = networkOutput.dup('c').reshape(mb, b, 5+c, h, w);
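The YOLO hunks keep minibatch/grid sizes as long and recover the class count C from the channel dimension before reshaping [minibatch, B*(5+C), H, W] into [minibatch, B, 5+C, H, W]. A compact sketch of that bookkeeping for a hypothetical network output (the B and C values are illustrative only):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    class YoloShapeSketch {
        public static void main(String[] args) {
            int bPriors = 5;        // number of bounding-box priors B (illustrative)
            int numClasses = 20;    // number of classes C (illustrative)
            // Activations format: [minibatch, B*(5+C), H, W]
            INDArray out = Nd4j.rand(DataType.FLOAT, 2, bPriors * (5 + numClasses), 13, 13);

            long mb = out.size(0);
            long h = out.size(2);
            long w = out.size(3);
            long b = bPriors;
            long c = out.size(1) / b - 5;   // recover C from the channel dimension

            // [minibatch, B*(5+C), H, W] -> [minibatch, B, 5+C, H, W]
            INDArray out5 = out.dup('c').reshape(mb, b, 5 + c, h, w);
            System.out.println(java.util.Arrays.toString(out5.shape()) + ", C=" + c);
        }
    }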
@@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.LayerHelper;
+import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
+import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
 import org.deeplearning4j.nn.params.LSTMParamInitializer;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;

@@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
                 }
             }
         }
+        /*
+        //Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
+        else if ("CPU".equalsIgnoreCase(backend) && BaseMKLDNNHelper.mklDnnEnabled()){
+            helper = new MKLDNNLSTMHelper();
+            log.debug("MKLDNNLSTMHelper successfully initialized");
+            if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
+                helper = null;
+            }
+        }
+        */
     }

     @Override
@@ -40,6 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.exception.ND4JOpProfilerException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -113,7 +114,9 @@ public class LSTMHelpers {

 input = input.castTo(inputWeights.dataType()); //No-op if already correct dtype

-// FIXME
+if ((!is2dInput && (input.size(2) > Integer.MAX_VALUE)) ||
+recurrentWeights.size(0) > Integer.MAX_VALUE || input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 int timeSeriesLength = (int) (is2dInput ? 1 : input.size(2));
 int hiddenLayerSize = (int) recurrentWeights.size(0);
 int miniBatchSize = (int) input.size(0);
@@ -550,7 +553,8 @@ public class LSTMHelpers {
 for (long iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
 try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.RNN_BP_LOOP_WORKING_MEM)) {

-// FIXME: int cast
+if (iTimeIndex > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 int time = (int) iTimeIndex;
 int inext = 1;

@@ -574,8 +578,6 @@ public class LSTMHelpers {
 (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]);
 INDArray currMemCellState = fwdPass.memCellState[(int) time];

-
-// FIXME: int cast
 //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
 INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
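Several hunks in this commit replace a bare "// FIXME: int cast" marker with an explicit range check before narrowing a long dimension to int. A minimal, self-contained sketch of that guard idiom (the array shape is illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

public class IntCastGuardSketch {
    public static void main(String[] args) {
        INDArray input = Nd4j.zeros(4, 3, 10);      // illustrative 3d input
        // INDArray dimensions are long; code paths that still need an int must check the range first
        if (input.size(0) > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        int miniBatchSize = (int) input.size(0);    // safe narrowing after the guard
        System.out.println(miniBatchSize);
    }
}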
@@ -89,8 +89,7 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
 ILossFunction lossFunction = layerConf().getLossFn();
 INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);

-// FIXME: int cast
-INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
+INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);

 // grab the empty gradient
 Gradient gradient = new DefaultGradient();
@@ -119,7 +118,6 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn

 @Override
 public int numLabels() {
-// FIXME: int cast
 return (int) labels.size(1);
 }

@@ -167,7 +165,7 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn

 INDArray as2d = TimeSeriesUtils.reshape3dTo2d(input);
 INDArray out2d = layerConf().getActivationFn().getActivation(workspaceMgr.dup(ArrayType.ACTIVATIONS, as2d, as2d.ordering()), training);
-return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, (int)input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
+return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
 }

 @Override
@@ -254,8 +252,7 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
 //scoreArray: shape [minibatch*timeSeriesLength, 1]
 //Reshape it to [minibatch, timeSeriesLength] then sum over time step

-// FIXME: int cast
-INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int) input.size(0));
+INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0));
 INDArray summedScores = scoreArrayTs.sum(1);

 if (fullNetRegTerm != 0.0) {
@@ -70,8 +70,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
 this.input = inputTemp;
 INDArray epsilon2d = gradAndEpsilonNext.getSecond();

-// FIXME: int cast
-INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
+INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);

 weightNoiseParams.clear();

@@ -145,8 +144,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
 }
 }

-// FIXME: int cast
-return TimeSeriesUtils.reshape2dTo3d(act2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
+return TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
 }

 @Override
@@ -205,8 +203,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
 //scoreArray: shape [minibatch*timeSeriesLength, 1]
 //Reshape it to [minibatch, timeSeriesLength] then sum over time step

-// FIXME: int cast
-INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int) input.size(0));
+INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0));
 INDArray summedScores = scoreArrayTs.sum(true, 1);

 if (fullNetRegTerm != 0.0) {
@@ -0,0 +1,68 @@
+package org.deeplearning4j.nn.layers.samediff;
+
+import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.factory.Nd4j;
+
+/**
+ * A SameDiff {@link org.nd4j.autodiff.samediff.internal.SessionMemMgr} that uses DL4J workspaces for memory management.
+ * Any op outputs are allocated in the output workspace if they are returned to the layer; otherwise they are placed in
+ * the DL4J working memory workspace
+ *
+ * @author Alex Black
+ */
+public class DL4JSameDiffMemoryMgr extends AbstractMemoryMgr {
+
+    private final String workingMemoryWs;
+    private final String outputWs;
+    private final WorkspaceConfiguration confWorking;
+    private final WorkspaceConfiguration confOutput;
+
+    //Note: if the working memory or output workspace names are null -> detached memory
+    public DL4JSameDiffMemoryMgr(String workingMemoryWs, String outputWs, WorkspaceConfiguration confWorking,
+                                 WorkspaceConfiguration confOutput){
+        this.workingMemoryWs = workingMemoryWs;
+        this.outputWs = outputWs;
+        this.confWorking = confWorking;
+        this.confOutput = confOutput;
+    }
+
+    @Override
+    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
+        String wsName = detached ? outputWs : workingMemoryWs;
+        WorkspaceConfiguration wsConf = detached ? confOutput : confWorking;
+
+        if(wsName == null){
+            //Scoped out
+            INDArray ret = Nd4j.createUninitializedDetached(dataType, shape);
+            Preconditions.checkState(!ret.isAttached(), "Returned array should be detached");
+            return ret;
+        } else {
+            MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, wsName);
+            try (MemoryWorkspace mw = ws.notifyScopeBorrowed()) {
+                return Nd4j.createUninitialized(dataType, shape);
+            }
+        }
+    }
+
+    @Override
+    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
+        return allocate(detached, descriptor.dataType(), descriptor.getShape());
+    }
+
+    @Override
+    public void release(INDArray array) {
+        //No-op - DL4J workspaces handles this
+    }
+
+    @Override
+    public void close() {
+        //No-op - DL4J workspaces handles this
+    }
+}
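For orientation, a minimal sketch of how a memory manager like the one above is attached to a per-thread SameDiff inference session; it mirrors the wiring added to the SameDiff layer and vertex classes later in this commit, with the workspace names and configurations passed in as parameters rather than taken from a LayerWorkspaceMgr:

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;

public class MemoryMgrWiringSketch {
    // Sketch only: in the real layers these arguments come from the LayerWorkspaceMgr.
    static void attachMemoryMgr(SameDiff sameDiff, String wsNameWorking, String wsNameOutput,
                                WorkspaceConfiguration confWorking, WorkspaceConfiguration confOutput) {
        SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
        InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
        if (is == null) {
            is = new InferenceSession(sameDiff);
            sameDiff.getSessions().put(Thread.currentThread().getId(), is);
        }
        is.setMmgr(mmgr);   // op outputs for this thread's session are now allocated via DL4J workspaces
    }
}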
@@ -31,9 +31,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.factory.Nd4j;
@@ -95,9 +98,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 @Override
 public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
 try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
-if(sameDiff == null){
+if (sameDiff == null) {
 doInit();
 }
+}

 Map<String,INDArray> phMap = new HashMap<>();
 config.validateInput(inputs);
@@ -112,6 +116,25 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 }
 }

+//Configure memory management for SameDiff instance - use DL4J workspaces
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+if(is == null){
+is = new InferenceSession(sameDiff);
+sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+}
+is.setMmgr(mmgr);
+
 if(paramTable != null && paramTable.size() > 0) {
 //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
 //TODO Find a more efficient solution for this
@@ -122,22 +145,29 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 }
 INDArray result = sameDiff.outputSingle(phMap, outputKey);

+//Edge case: "vertex" is just an identity activation, for example
+//TODO there may be a cleaner way to do this...
+if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
+result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+} else if(actScopedOut && result.isAttached()){
+result = result.detach();
+}
+
 //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
 sameDiff.clearPlaceholders(true);
 sameDiff.clearOpInputs();
 return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
 }
-}

 @Override
 public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
 Gradient g = new DefaultGradient();

-INDArray[] dLdIns;
-try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
-if(sameDiff == null){
+try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
+if (sameDiff == null) {
 doInit();
 }
+}

 List<String> inputNames = config.getVertexParams().getInputs();
 if(!sameDiff.hasGradientFunction()) {
@@ -146,6 +176,24 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 sameDiff.createGradFunction(inArr);
 }
 config.validateInput(inputs);
+
+//Configure memory management for SameDiff instance - use DL4J workspaces
+Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+if(!sessionMap.containsKey(Thread.currentThread().getId())){
+sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+}
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+
+boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
+
 Map<String,INDArray> phMap = new HashMap<>();
 List<String> inputs = config.getVertexParams().getInputs();
 int i=0;
@@ -167,41 +215,43 @@ public class SameDiffGraphVertex extends BaseGraphVertex {

 //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
 //TODO Find a more efficient solution for this
+List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
 for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
 INDArray arr = e.getValue();
 sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
 }

-List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
-for(String s : inputNames){
-required.add(sameDiff.getVariable(s).gradient().getVarName());
-}
-sameDiff.execBackwards(phMap, required);
+required.addAll(paramTable.keySet());
+required.addAll(inputNames);
+Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
 for(String s : paramTable.keySet() ){
-INDArray sdGrad = sameDiff.grad(s).getArr();
+INDArray sdGrad = gradsMap.get(s);
 INDArray dl4jGrad = gradTable.get(s);
 dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
 g.gradientForVariable().put(s, dl4jGrad);
 }

-dLdIns = new INDArray[inputs.size()];
+INDArray[] dLdIns = new INDArray[inputs.size()];
 String fnName = fn.getGradPlaceholderName();
 for(int j=0; j<inputs.size(); j++ ){
 String name = inputs.get(j);
 dLdIns[j] = sameDiff.grad(name).getArr();

-String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
+String gradName = sameDiff.grad(inputNames.get(j)).name();
 if(dLdIns[j] == null && fnName.equals(gradName)){
 //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
 // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
 dLdIns[j] = epsilon;
 }
-}
-}
-
-//TODO optimize
-for( int i=0; i<dLdIns.length; i++ ){
-dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
+
+//Edge case: "vertex" is just an identity activation, for example
+//TODO there may be a cleaner way to do this...
+if(!actGradScopedOut && !dLdIns[j].data().getParentWorkspace().getId().equals(wsNameActGrad)){
+dLdIns[j] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[j]);
+} else if(actGradScopedOut && dLdIns[j].isAttached()){
+dLdIns[j] = dLdIns[j].detach();
+}
 }

 //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@@ -264,7 +314,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
 fn = sameDiff.f().externalErrors(layerOutput);
 fn.outputVariable();

-this.outputKey = outputVar.getVarName();
+this.outputKey = outputVar.name();
 }
 }
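The hunks above switch gradient collection from sameDiff.execBackwards(...) followed by per-variable grad(...).getArr() lookups to sameDiff.calculateGradients(...), which returns the requested gradients as a map keyed by variable name. A minimal sketch of the new call pattern (the placeholder map and the variable names "input", "W" and "b" are illustrative assumptions):

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

public class CalculateGradientsSketch {
    static void collectGradients(SameDiff sd, Map<String, INDArray> placeholders) {
        List<String> required = Arrays.asList("input", "W", "b");   // variables whose gradients we want
        // New style: gradients come back directly, keyed by the variable they belong to
        Map<String, INDArray> grads = sd.calculateGradients(placeholders, required);
        INDArray dLdW = grads.get("W");
        INDArray dLdIn = grads.get("input");
        System.out.println(dLdW.shapeInfoToString() + " " + dLdIn.shapeInfoToString());
    }
}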
@@ -26,9 +26,12 @@ import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.AbstractLayer;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.factory.Nd4j;
@@ -81,9 +84,10 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 assertInputSet(false);

 try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
-if(sameDiff == null){
+if (sameDiff == null) {
 doInit();
 }
+}

 org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
 bl.validateInput(input);
@@ -103,15 +107,39 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
 }

+//Configure memory management for SameDiff instance - use DL4J workspaces
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+if(is == null){
+is = new InferenceSession(sameDiff);
+sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+}
+is.setMmgr(mmgr);
+
 Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
 INDArray result = out.get(outputKey);

+//Edge case - identity activation
+//TODO there may be a cleaner way to do this...
+if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
+result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+} else if(actScopedOut && result.isAttached()){
+result = result.detach();
+}
+
 //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
 sameDiff.clearPlaceholders(true);
 sameDiff.clearOpInputs();

-return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
-}
+return result;
 }

@@ -122,14 +150,31 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 Gradient g = new DefaultGradient();

 INDArray dLdIn;
-try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
-if(sameDiff == null){
+try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
+if (sameDiff == null) {
 doInit();
 }
-if(!sameDiff.hasGradientFunction()) {
+if (!sameDiff.hasGradientFunction()) {
 //Create when scoped out, to ensure any arrays are not in WS
 sameDiff.createGradFunction(INPUT_KEY);
 }
+}
+//Configure memory management for SameDiff instance - use DL4J workspaces
+Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+if(!sessionMap.containsKey(Thread.currentThread().getId())){
+sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+}
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+
+boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
+
 org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
 bl.validateInput(input);
@@ -151,34 +196,26 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 }

 List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
-requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
-for(String s : paramTable.keySet()){
-requiredGrads.add(sameDiff.grad(s).getVarName());
-}
+requiredGrads.add(INPUT_KEY);
+requiredGrads.addAll(paramTable.keySet());

-sameDiff.execBackwards(phMap, requiredGrads);
+Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
 for(String s : paramTable.keySet() ){
-INDArray sdGrad = sameDiff.grad(s).getArr();
+INDArray sdGrad = m.get(s);
 INDArray dl4jGrad = gradTable.get(s);
 dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
 g.gradientForVariable().put(s, dl4jGrad);
 }

-SDVariable v = sameDiff.grad(INPUT_KEY);
-dLdIn = v.getArr();
-
-if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
-//Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
-// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
-dLdIn = epsilon;
-}
-}
+dLdIn = m.get(INPUT_KEY);
 }

 //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
 sameDiff.clearPlaceholders(true);
 sameDiff.clearOpInputs();

-return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
+Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
+return ret;
 }

 /**Returns the parameters of the neural network as a flattened row vector
@@ -291,7 +328,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 fn = sameDiff.f().externalErrors(layerOutput);
 fn.outputVariable();

-this.outputKey = outputVar.getVarName();
+this.outputKey = outputVar.name();
 }
 }
@@ -29,9 +29,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.dataset.api.DataSet;
@@ -95,9 +98,28 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con

 //TODO optimize
 try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
-if(sameDiff == null){
+if (sameDiff == null) {
 doInit();
 }
+}
+
+//Configure memory management for SameDiff instance - use DL4J workspaces
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+if(is == null){
+is = new InferenceSession(sameDiff);
+sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+}
+is.setMmgr(mmgr);
+
 //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
 //TODO Find a more efficient solution for this
@@ -112,7 +134,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
 phMap.put(LABELS_KEY, labels);
 }

-String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
+String s = activations ? layerConf().activationsVertexName() : outputVar.name();

 INDArray out = sameDiff.outputSingle(phMap, s);

@@ -120,16 +142,16 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
 sameDiff.clearPlaceholders(true);
 sameDiff.clearOpInputs();

-if(activations) {
-Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
-"null - error during execution or this variable (as defined by method activationsVertexName()) " +
-"does not exist", layerConf().activationsVertexName());
-return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
-} else {
+//Edge case: vertex is just an Identity function, for example
+//TODO there may be a cleaner way to do this...
+if(!actScopedOut && !out.data().getParentWorkspace().getId().equals(wsNameOutput)){
+out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
+} else if(actScopedOut && out.isAttached()){
+out = out.detach();
+}
+
 return out;
 }
-}
-}

 @Override
@@ -141,12 +163,31 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
 Gradient g = new DefaultGradient();

 INDArray dLdIn;
-try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
-if(sameDiff == null){
+try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
+if (sameDiff == null) {
 //Usually doInit will be called in forward pass; not necessarily the case in output layers
 // (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
 doInit();
 }
+if(sameDiff.getFunction("grad") == null)
+sameDiff.createGradFunction(INPUT_KEY);
+}
+
+//Configure memory management for SameDiff instance - use DL4J workspaces
+Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+if(!sessionMap.containsKey(Thread.currentThread().getId())){
+sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+}
+String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+
+boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
+
 if(!sameDiff.hasGradientFunction()) {
 //Create when scoped out, to ensure any arrays are not in WS
 sameDiff.createGradFunction(INPUT_KEY);
@@ -160,31 +201,38 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
 }

 List<String> gradVarNames = new ArrayList<>();
-for(String s : paramTable.keySet()){
-gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
-}
-gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
+gradVarNames.addAll(paramTable.keySet());
+gradVarNames.add(INPUT_KEY);

 Map<String,INDArray> phMap = new HashMap<>();
 phMap.put(INPUT_KEY, input);
 phMap.put(LABELS_KEY, labels);

-sameDiff.execBackwards(phMap, gradVarNames);
+Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
 for(String s : paramTable.keySet() ){
-INDArray sdGrad = sameDiff.grad(s).getArr();
+INDArray sdGrad = grads.get(s);
 INDArray dl4jGrad = gradTable.get(s);
 dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
 g.gradientForVariable().put(s, dl4jGrad);
+if(sdGrad.closeable()){
+sdGrad.close();
+}
 }

-dLdIn = sameDiff.grad(INPUT_KEY).getArr();
-}
+dLdIn = grads.get(INPUT_KEY);

 //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
 sameDiff.clearPlaceholders(true);
 sameDiff.clearOpInputs();

-return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
+//TODO there may be a cleaner way to do this...
+if(!actGradScopedOut && !dLdIn.data().getParentWorkspace().getId().equals(wsNameActGrad)){
+dLdIn = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn);
+} else if(actGradScopedOut && dLdIn.isAttached()){
+dLdIn = dLdIn.detach();
+}
+
+return new Pair<>(g, dLdIn);
 }

 /**Returns the parameters of the neural network as a flattened row vector
@@ -297,7 +345,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
 sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
 }

-this.outputKey = layerOutput.getVarName();
+this.outputKey = layerOutput.name();
 }
 }

@@ -308,7 +356,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con

 @Override
 public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
-return (activateHelper(false, workspaceMgr).getDouble(0) + fullNetRegTerm) / input.size(0);
+INDArray scoreArr = activateHelper(false, workspaceMgr);
+return (scoreArr.getDouble(0) + fullNetRegTerm) / input.size(0);
 }

 @Override
@@ -41,6 +41,7 @@ import org.nd4j.linalg.api.blas.Level1;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
@@ -552,7 +553,8 @@ public class VariationalAutoencoder implements Layer {

 @Override
 public int batchSize() {
-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return (int) input.size(0);
 }

@@ -862,7 +864,8 @@ public class VariationalAutoencoder implements Layer {

 @Override
 public int getInputMiniBatchSize() {
-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return (int) input.size(0);
 }
@@ -75,6 +75,7 @@ import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.MultiDataSet;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.heartbeat.Heartbeat;
 import org.nd4j.linalg.heartbeat.reports.Environment;
@@ -425,7 +426,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
 if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) {

-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx).preProcess(outputOfPrevLayer, (int) input.size(0),
 LayerWorkspaceMgr.noWorkspaces(helperWorkspaces));
 }
@@ -439,7 +441,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 //In 99+% of cases, the input and labels dimension 0 size should be identical
 //The only real exceptions: space to batch, and batch to space layers
 //In those cases, we should base it on the labels size, as this impacts gradient calculation
-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return labels == null ? (int) input.size(0) : (int)labels.size(0);
 }

@@ -2074,7 +2077,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 if (endTimeIdx > timeSeriesLength)
 endTimeIdx = timeSeriesLength;

-// FIXME: int cast
+if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
 featuresMaskArray, labelsMaskArray);

@@ -2211,7 +2215,9 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 public int[] predict(INDArray d) {
 INDArray output = output(d, Layer.TrainingMode.TEST);

-// FIXME: int cast
+if (d.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
+
 int[] ret = new int[(int) d.size(0)];
 if (d.isRowVectorOrScalar())
 ret[0] = Nd4j.getBlasWrapper().iamax(output);
@@ -2335,7 +2341,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 org.deeplearning4j.nn.conf.layers.OutputLayer layerConf =
 (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer();

-// FIXME: int cast
+if (layerConf.getNOut() > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut()));
 }

@@ -2584,7 +2591,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD,layers.length-2, data.getFeatures(),
 data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null);

-// FIXME: int cast
+if (data.getFeatures().size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 IOutputLayer ol = (IOutputLayer) getOutputLayer();
 if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) {
 inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1)
@@ -2647,7 +2655,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 IOutputLayer ol = (IOutputLayer) getOutputLayer();
 if(layerWiseConfigurations.getInputPreProcess(layers.length-1) != null){

-// FIXME: int cast
+if (data.getFeatures().size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 inputLast = layerWiseConfigurations.getInputPreProcess(layers.length-1).preProcess(inputLast,
 (int) data.getFeatures().size(0), mgr);
 }
@@ -2811,7 +2820,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 throw new IllegalArgumentException(
 "Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")");

-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 setInputMiniBatchSize((int) input.size(0));
 }
 }
@@ -3086,7 +3096,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 if(!conf().isMiniBatch())
 return 1;

-// FIXME: int cast
+if (input.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return (int) input.size(0);
 }

@@ -3256,7 +3267,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
 if (featuresMaskArray != null) {

-// FIXME: int cast
+if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 //New approach: use feedForwardMaskArray method
 feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));

@@ -3438,7 +3450,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 val startTimeIdx = i * fwdLen;
 val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);

-// FIXME: int cast
+if (endTimeIdx > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask);

 setLayerMaskArrays(subsets[2], subsets[3]);
@@ -3943,7 +3956,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 }
 FeedForwardLayer ffl = (FeedForwardLayer) conf;

-// FIXME: int cast
+if (ffl.getNOut() > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return (int) ffl.getNOut();
 }

@@ -3969,7 +3983,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
 }
 FeedForwardLayer ffl = (FeedForwardLayer) conf;

-// FIXME: int cast
+if (ffl.getNIn() > Integer.MAX_VALUE)
+throw new ND4JArraySizeException();
 return (int) ffl.getNIn();
 }
@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -108,7 +109,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
|
||||||
}
|
}
|
||||||
|
|
||||||
//Between last decoder layer and parameters for p(x|z):
|
//Between last decoder layer and parameters for p(x|z):
|
||||||
// FIXME: int cast
|
if (nIn > Integer.MAX_VALUE)
|
||||||
|
throw new ND4JArraySizeException();
|
||||||
val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
|
val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
|
||||||
val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
|
val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
|
||||||
paramCount += (lastDecLayerSize + 1) * nDistributionParams;
|
paramCount += (lastDecLayerSize + 1) * nDistributionParams;
|
||||||
|
@ -294,7 +296,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
     }

     //Finally, p(x|z):
-    // FIXME: int cast
+    if (nIn > Integer.MAX_VALUE)
+        throw new ND4JArraySizeException();
     int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
     int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
     INDArray pxzWeightView =
@ -402,7 +405,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
     }

     //Finally, p(x|z):
-    // FIXME: int cast
+    if (nIn > Integer.MAX_VALUE)
+        throw new ND4JArraySizeException();
     int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
     int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
     INDArray pxzWeightView =
@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ops.CustomOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;

 import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.deeplearning4j.nn.workspace.ArrayType;
@ -111,7 +112,8 @@ public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater
     if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable,
                     layers[i], var)) {

-        // FIXME: int cast
+        if (paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE || paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE)
+            throw new ND4JArraySizeException();
         //Create a new block
         List<UpdaterBlock.ParamState> list = new ArrayList<>();
         list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
@ -122,9 +124,11 @@ public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater

         updaterBlocks.add(currentBlock);
     } else {
-        // FIXME: int cast
+        long newOffset = currentBlock.getParamOffsetEnd() + paramSizeThisVariable;
+        if (newOffset > Integer.MAX_VALUE)
+            throw new ND4JArraySizeException();
         //Add to existing updater block
-        currentBlock.setParamOffsetEnd((int) (currentBlock.getParamOffsetEnd() + paramSizeThisVariable));
+        currentBlock.setParamOffsetEnd((int) newOffset);
         currentBlock.setUpdaterViewOffsetEnd(
                         currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
         currentBlock.getLayersAndVariablesInBlock()
@ -25,6 +25,7 @@ import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;

 /**
@ -37,7 +38,83 @@ public class CollectScoresIterationListener extends BaseTrainingListener {

     private int frequency;
     private int iterationCount = 0;
-    private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
+    //private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
+
+    public static class ScoreStat {
+        public static final int BUCKET_LENGTH = 10000;
+
+        private int position = 0;
+        private int bucketNumber = 1;
+        private List<long[]> indexes;
+        private List<double[]> scores;
+
+        public ScoreStat() {
+            indexes = new ArrayList<>(1);
+            indexes.add(new long[BUCKET_LENGTH]);
+            scores = new ArrayList<>(1);
+            scores.add(new double[BUCKET_LENGTH]);
+        }
+
+        public List<long[]> getIndexes() {
+            return indexes;
+        }
+
+        public List<double[]> getScores() {
+            return scores;
+        }
+
+        public long[] getEffectiveIndexes() {
+            return Arrays.copyOfRange(indexes.get(0), 0, position);
+        }
+
+        public double[] getEffectiveScores() {
+            return Arrays.copyOfRange(scores.get(0), 0, position);
+        }
+
+        /*
+            Originally scores array is initialized with BUCKET_LENGTH size.
+            When data doesn't fit there - arrays size is increased for BUCKET_LENGTH,
+            old data is copied and bucketNumber (counter of reallocations) being incremented.
+
+            If we got more score points than MAX_VALUE - they are put to another item of scores list.
+        */
+        private void reallocateGuard() {
+            if (position >= BUCKET_LENGTH * bucketNumber) {
+
+                long fullLength = (long)BUCKET_LENGTH * bucketNumber;
+
+                if (position == Integer.MAX_VALUE || fullLength >= Integer.MAX_VALUE) {
+                    position = 0;
+                    long[] newIndexes = new long[BUCKET_LENGTH];
+                    double[] newScores = new double[BUCKET_LENGTH];
+                    indexes.add(newIndexes);
+                    scores.add(newScores);
+                }
+                else {
+                    long[] newIndexes = new long[(int)fullLength + BUCKET_LENGTH];
+                    double[] newScores = new double[(int)fullLength + BUCKET_LENGTH];
+                    System.arraycopy(indexes.get(indexes.size()-1), 0, newIndexes, 0, (int)fullLength);
+                    System.arraycopy(scores.get(scores.size()-1), 0, newScores, 0, (int)fullLength);
+                    scores.remove(scores.size()-1);
+                    indexes.remove(indexes.size()-1);
+                    int lastIndex = scores.size() == 0 ? 0 : scores.size()-1;
+                    scores.add(lastIndex, newScores);
+                    indexes.add(lastIndex, newIndexes);
+                }
+                bucketNumber += 1;
+            }
+        }
+
+        public void addScore(long index, double score) {
+            reallocateGuard();
+            scores.get(scores.size() - 1)[position] = score;
+            indexes.get(scores.size() - 1)[position] = index;
+            position += 1;
+        }
+    }
+
+    ScoreStat scoreVsIter = new ScoreStat();

     /**
     * Constructor for collecting scores with default saving frequency of 1
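The ScoreStat type added above replaces the boxed List<Pair<Integer, Double>> with primitive long/double buckets that grow by BUCKET_LENGTH and spill into an extra backing array once the accumulated length approaches Integer.MAX_VALUE. A rough usage sketch follows, assuming the listener lives in org.deeplearning4j.optimize.listeners (the package is not shown in this diff):

import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;

public class ScoreStatSketch {
    public static void main(String[] args) {
        // Record (iteration, score) pairs directly against the primitive-backed store.
        CollectScoresIterationListener.ScoreStat stat = new CollectScoresIterationListener.ScoreStat();
        for (int i = 0; i < 25_000; i++) {          // more than two BUCKET_LENGTH chunks, forcing reallocation
            stat.addScore(i, 1.0 / (i + 1));
        }

        // The "effective" views are trimmed to the number of points actually written.
        long[] iterations = stat.getEffectiveIndexes();
        double[] scores = stat.getEffectiveScores();
        System.out.println(iterations.length + " scores collected, last = " + scores[scores.length - 1]);
    }
}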
@ -60,11 +137,12 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
     public void iterationDone(Model model, int iteration, int epoch) {
         if (++iterationCount % frequency == 0) {
             double score = model.score();
-            scoreVsIter.add(new Pair<>(iterationCount, score));
+            scoreVsIter.reallocateGuard();
+            scoreVsIter.addScore(iteration, score);
         }
     }

-    public List<Pair<Integer, Double>> getScoreVsIter() {
+    public ScoreStat getScoreVsIter() {
         return scoreVsIter;
     }
@ -84,8 +162,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
     public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
         StringBuilder sb = new StringBuilder();
         sb.append("Iteration").append(delimiter).append("Score");
-        for (Pair<Integer, Double> p : scoreVsIter) {
-            sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond());
+        int largeBuckets = scoreVsIter.indexes.size();
+        for (int j = 0; j < largeBuckets; ++j) {
+            long[] indexes = scoreVsIter.indexes.get(j);
+            double[] scores = scoreVsIter.scores.get(j);
+
+            int effectiveLength = (j < largeBuckets - 1) ? indexes.length : scoreVsIter.position;
+
+            for (int i = 0; i < effectiveLength; ++i) {
+                sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]);
+            }
         }
         outputStream.write(sb.toString().getBytes("UTF-8"));
     }
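For reference, a short sketch of how the reworked exportScores might be driven after training. The output layout (an "Iteration<delimiter>Score" header followed by one row per recorded point) comes directly from the code above; the import path and the setListeners call mentioned in the comment are assumptions.

import java.io.ByteArrayOutputStream;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;

public class ExportScoresSketch {
    public static void main(String[] args) throws Exception {
        CollectScoresIterationListener listener = new CollectScoresIterationListener();
        // Normally the listener is attached via model.setListeners(listener) and
        // filled during fit(); only the export side is shown here.

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        listener.exportScores(out, ",");   // "Iteration,Score" header plus one row per recorded point
        System.out.println(out.toString("UTF-8"));
    }
}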
@ -29,6 +29,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.Nd4j;

 import java.util.Arrays;
@ -62,10 +63,9 @@ public class Convolution1DUtils {
     * @param dilation Kernel dilation
     * @return Output size (width)
     */
-    public static int getOutputSize(int inH, int kernel, int strides, int padding,
+    public static long getOutputSize(long inH, int kernel, int strides, int padding,
                     ConvolutionMode convolutionMode, int dilation) {
-        // FIXME: int cast
-        int eKernel = effectiveKernelSize(kernel, dilation);
+        long eKernel = effectiveKernelSize(kernel, dilation);
         if (convolutionMode == ConvolutionMode.Same) {
             return (int) Math.ceil(inH / ((double) strides));
         }
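As a quick check on the Same-mode branch kept above, the output width is simply ceil(inH / strides), independent of the kernel size. A worked example with arbitrarily chosen values:

public class Conv1DSameModeSketch {
    public static void main(String[] args) {
        long inH = 15;       // input sequence length
        int strides = 2;
        // Same-mode output size, computed exactly as in getOutputSize above
        long out = (long) Math.ceil(inH / ((double) strides));
        System.out.println(out);   // prints 8
    }
}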
@ -85,7 +85,8 @@ public class Convolution1DUtils {
     */
    public static int getOutputSize(INDArray inputData, int kernel, int strides, int padding,
                     ConvolutionMode convolutionMode, int dilation) {
-        // FIXME: int cast
+        if (inputData.size(2) > Integer.MAX_VALUE)
+            throw new ND4JArraySizeException();
         int inH = (int) inputData.size(2);
         int eKernel = effectiveKernelSize(kernel, dilation);
         boolean atrous = (eKernel == kernel);
@ -61,15 +61,14 @@ public class Convolution3DUtils {
                     ConvolutionMode convolutionMode, int[] dilation, boolean isNCDHW) {

         // NCDHW vs. NDHWC
-        int inD = (int) (isNCDHW ? inputData.size(2) : inputData.size(1));
-        int inH = (int) (isNCDHW ? inputData.size(3) : inputData.size(2));
-        int inW = (int) (isNCDHW ? inputData.size(4) : inputData.size(3));
+        long inD = (isNCDHW ? inputData.size(2) : inputData.size(1));
+        long inH = (isNCDHW ? inputData.size(3) : inputData.size(2));
+        long inW = (isNCDHW ? inputData.size(4) : inputData.size(3));

         int[] eKernel = effectiveKernelSize(kernel, dilation);
         boolean atrous = (eKernel == kernel);

-        // FIXME: int cast
-        val inShape = new int[]{inD, inH, inW};
+        val inShape = new long[]{inD, inH, inW};
         validateShapes(ArrayUtil.toInts(inputData.shape()), eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);

         if (convolutionMode == ConvolutionMode.Same) {
@ -80,16 +79,16 @@ public class Convolution3DUtils {
             return new int[]{outD, outH, outW};
         }

-        int outD = (inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1;
-        int outH = (inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1;
-        int outW = (inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1;
+        int outD = ((int)inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1;
+        int outH = ((int)inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1;
+        int outW = ((int)inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1;

         return new int[]{outD, outH, outW};
     }


     private static void validateShapes(int[] inputDataShape, int[] eKernel, int[] strides, int[] padding,
-                    ConvolutionMode convolutionMode, int[] dilation, int[] inShape,
+                    ConvolutionMode convolutionMode, int[] dilation, long[] inShape,
                     boolean atrous) {

         String[] dims = new String[]{"depth", "height", "width"};
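The non-Same branch above uses the usual strided-convolution size formula, out = (in - effectiveKernel + 2 * padding) / stride + 1 per spatial dimension, with integer division. A small worked check with arbitrarily chosen values:

public class Conv3DOutputSizeSketch {
    public static void main(String[] args) {
        int inD = 16, inH = 32, inW = 32;        // input volume
        int[] eKernel = {3, 3, 3};               // effective kernel (dilation already applied)
        int[] strides = {2, 2, 2};
        int[] padding = {1, 1, 1};

        // Same formula as the reconstructed lines above, one axis at a time
        int outD = (inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1;   // (16-3+2)/2+1 = 8
        int outH = (inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1;   // (32-3+2)/2+1 = 16
        int outW = (inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1;   // (32-3+2)/2+1 = 16
        System.out.println(outD + " x " + outH + " x " + outW);            // 8 x 16 x 16
    }
}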
Some files were not shown because too many files have changed in this diff.