Merge remote-tracking branch 'konduit/master'

master
AlexDBlack 2019-11-02 19:00:47 +11:00
commit 2844f8b69a
463 changed files with 15151 additions and 12178 deletions

View File

@ -124,7 +124,6 @@ public class TestUtils {
public static INDArray randomOneHot(long examples, long nOut, Random rng){ public static INDArray randomOneHot(long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(examples, nOut); INDArray arr = Nd4j.create(examples, nOut);
for( int i=0; i<examples; i++ ){ for( int i=0; i<examples; i++ ){
// FIXME: int cast
arr.putScalar(i, rng.nextInt((int) nOut), 1.0); arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
} }
return arr; return arr;

View File

@ -0,0 +1,107 @@
package org.deeplearning4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import static junit.framework.TestCase.*;
public class TestBatchNormBp {
@Test
public void test(){
Nd4j.getRandom().setSeed(12345);
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
double e = 1e-5;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, eps, gamma, beta)
.addIntegerArguments(
1, //Apply scale
1, //Apply beta
1) //Axis (NCHW)
.addFloatingPointArguments(e)
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
.build();
Nd4j.exec(op);
System.out.println(dLdIn);
}
@Test
public void compareImpls() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
}
}

View File

@ -132,7 +132,6 @@ public class TestUtils {
public static INDArray randomOneHot(long examples, long nOut, Random rng){ public static INDArray randomOneHot(long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(examples, nOut); INDArray arr = Nd4j.create(examples, nOut);
for( int i=0; i<examples; i++ ){ for( int i=0; i<examples; i++ ){
// FIXME: int cast
arr.putScalar(i, rng.nextInt((int) nOut), 1.0); arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
} }
return arr; return arr;

View File

@ -187,7 +187,6 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest {
} }
} }
@Test @Test
public void testVariableTimeSeries2() throws Exception { public void testVariableTimeSeries2() throws Exception {
AsyncDataSetIterator adsi = AsyncDataSetIterator adsi =

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
@ -242,7 +243,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init(); model.init();
model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq))); //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
model.setListeners(listener);
model.fit(cifar); model.fit(cifar);
@ -254,6 +258,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
eval.eval(testDS.getLabels(), output); eval.eval(testDS.getLabels(), output);
} }
System.out.println(eval.stats(true)); System.out.println(eval.stats(true));
listener.exportScores(System.out);
} }

View File

@ -464,13 +464,11 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
ret[1] = Nd4j.zeros(labelsShape); ret[1] = Nd4j.zeros(labelsShape);
if (labelsShape.length == 2) { if (labelsShape.length == 2) {
for (int i = 0; i < labelsShape[0]; i++) { for (int i = 0; i < labelsShape[0]; i++) {
// FIXME: int cast
ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), 1.0); ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), 1.0);
} }
} else if (labelsShape.length == 3) { } else if (labelsShape.length == 3) {
for (int i = 0; i < labelsShape[0]; i++) { for (int i = 0; i < labelsShape[0]; i++) {
for (int j = 0; j < labelsShape[2]; j++) { for (int j = 0; j < labelsShape[2]; j++) {
// FIXME: int cast
ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, 1.0); ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, 1.0);
} }
} }
@ -484,13 +482,11 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
ret[1] = Nd4j.ones(labelsShape); ret[1] = Nd4j.ones(labelsShape);
if (labelsShape.length == 2) { if (labelsShape.length == 2) {
for (int i = 0; i < labelsShape[0]; i++) { for (int i = 0; i < labelsShape[0]; i++) {
// FIXME: int cast
ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), -1.0); ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), -1.0);
} }
} else if (labelsShape.length == 3) { } else if (labelsShape.length == 3) {
for (int i = 0; i < labelsShape[0]; i++) { for (int i = 0; i < labelsShape[0]; i++) {
for (int j = 0; j < labelsShape[2]; j++) { for (int j = 0; j < labelsShape[2]; j++) {
// FIXME: int cast
ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, -1.0); ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, -1.0);
} }
} }

View File

@ -176,8 +176,7 @@ public class ShiftVertexTest extends BaseDL4JTest {
manual_weights.put("output_b", c); manual_weights.put("output_b", c);
// First things first, let's calculate the score. // First things first, let's calculate the score.
// FIXME: int cast long batchsz = input.shape()[0];
int batchsz = (int) input.shape()[0];
INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1)); INDArray z = input.castTo(W.dataType()).mmul(W).add(b.repmat(batchsz, 1));
INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies it's input!! INDArray a = a1.getActivation(z.dup(), true).add(sf); // activation modifies it's input!!
INDArray q = a.mmul(V).add(c.repmat(batchsz, 1)); INDArray q = a.mmul(V).add(c.repmat(batchsz, 1));

View File

@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test @Test
public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() { public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test @Test
public void testComputationGraphFrozenLayerParamsAfterBackprop() { public void testComputationGraphFrozenLayerParamsAfterBackprop() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4,12345), Nd4j.rand(100, 1, 12345)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-"; String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-"; String unfrozenBranchName = "B2-";
@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
*/ */
@Test @Test
public void testFrozenLayerVsSgd() { public void testFrozenLayerVsSgd() {
DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Sgd(2)) .updater(new Sgd(2))
.list() .list()
.layer(0, .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
new DenseLayer.Builder() .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
.nIn(4) .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
.nOut(3) .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
.build()
)
.layer(1,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(3)
.nOut(4)
.build()
).layer(2,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(4)
.nOut(2)
.build()
).layer(3,
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.activation(Activation.TANH)
.nIn(2)
.nOut(1)
.build()
)
.build(); .build();
MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Sgd(2)) .updater(new Sgd(2))
.list() .list()
.layer(0, .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
new DenseLayer.Builder() .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
.nIn(4) .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
.nOut(3) .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build()
)
.layer(1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
)
)
.layer(2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build()
)
).layer(3,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.TANH)
.nIn(2)
.nOut(1)
.build()
)
)
.build(); .build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init(); frozenNetwork.init();
@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
@Test @Test
public void testComputationGraphVsSgd() { public void testComputationGraphVsSgd() {
Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
String frozenBranchName = "B1-"; String frozenBranchName = "B1-";
String unfrozenBranchName = "B2-"; String unfrozenBranchName = "B2-";
@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345) .seed(12345)
.graphBuilder() .graphBuilder()
.addInputs("input") .addInputs("input")
.addLayer(initialLayer, .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
new DenseLayer.Builder() .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
.nIn(4) .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
.nOut(4) new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.build(),
"input"
)
.addLayer(frozenBranchUnfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(3)
.build(),
initialLayer
)
.addLayer(frozenBranchFrozenLayer1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
),
frozenBranchUnfrozenLayer0
)
.addLayer(frozenBranchFrozenLayer2, .addLayer(frozenBranchFrozenLayer2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder() new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.nIn(4) .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.nOut(2) .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.build() .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
), .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
frozenBranchFrozenLayer1 .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
) new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.addLayer(unfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
initialLayer
)
.addLayer(unfrozenLayer1,
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build(),
unfrozenLayer0
)
.addLayer(unfrozenBranch2,
new DenseLayer.Builder()
.nIn(2)
.nOut(1)
.build(),
unfrozenLayer1
)
.addVertex("merge",
new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.TANH)
.nIn(3)
.nOut(1)
.build()
),
"merge"
)
.setOutputs(frozenBranchOutput) .setOutputs(frozenBranchOutput)
.build(); .build();
@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345) .seed(12345)
.graphBuilder() .graphBuilder()
.addInputs("input") .addInputs("input")
.addLayer(initialLayer, .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
new DenseLayer.Builder() .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.nIn(4) .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
.nOut(4) .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
.build(), .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
"input" .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
) .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addLayer(frozenBranchUnfrozenLayer0, .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
new DenseLayer.Builder() .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
.nIn(4)
.nOut(3)
.build(),
initialLayer
)
.addLayer(frozenBranchFrozenLayer1,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(3)
.nOut(4)
.build(),
frozenBranchUnfrozenLayer0
)
.addLayer(frozenBranchFrozenLayer2,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(4)
.nOut(2)
.build()
,
frozenBranchFrozenLayer1
)
.addLayer(unfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
initialLayer
)
.addLayer(unfrozenLayer1,
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build(),
unfrozenLayer0
)
.addLayer(unfrozenBranch2,
new DenseLayer.Builder()
.nIn(2)
.nOut(1)
.build(),
unfrozenLayer1
)
.addVertex("merge",
new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.activation(Activation.TANH)
.nIn(3)
.nOut(1)
.build()
,
"merge"
)
.setOutputs(frozenBranchOutput) .setOutputs(frozenBranchOutput)
.build(); .build();

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
@ -693,4 +694,22 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
INDArray out = net.output(in); INDArray out = net.output(in);
assertArrayEquals(new long[]{2,7,6}, out.shape()); assertArrayEquals(new long[]{2,7,6}, out.shape());
} }
@Test
public void testDeconvBadInput(){
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5);
try {
net.output(badInput);
} catch (DL4JInvalidInputException e){
String msg = e.getMessage();
assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
}
}
} }

View File

@ -80,6 +80,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
@Test @Test
public void testSameDiffDenseForward() { public void testSameDiffDenseForward() {
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
for (int minibatch : new int[]{5, 1}) { for (int minibatch : new int[]{5, 1}) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -97,8 +98,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
}; };
for (Activation a : afns) { for (Activation a : afns) {
log.info("Starting test - " + a); log.info("Starting test - " + a + ", workspace = " + wsm);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(wsm)
.trainingWorkspaceMode(wsm)
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut) .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
.activation(a) .activation(a)
@ -146,9 +149,11 @@ public class TestSameDiffDense extends BaseDL4JTest {
} }
} }
} }
}
@Test @Test
public void testSameDiffDenseForwardMultiLayer() { public void testSameDiffDenseForwardMultiLayer() {
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
for (int minibatch : new int[]{5, 1}) { for (int minibatch : new int[]{5, 1}) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -166,7 +171,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
}; };
for (Activation a : afns) { for (Activation a : afns) {
log.info("Starting test - " + a); log.info("Starting test - " + a + " - workspace=" + wsm);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.list() .list()
@ -201,7 +206,6 @@ public class TestSameDiffDense extends BaseDL4JTest {
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
// net.params().assign(net2.params());
assertEquals(net2.params(), net.params()); assertEquals(net2.params(), net.params());
//Check params: //Check params:
@ -231,6 +235,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
} }
} }
} }
}
@Test @Test
public void testSameDiffDenseBackward() { public void testSameDiffDenseBackward() {
@ -244,10 +249,13 @@ public class TestSameDiffDense extends BaseDL4JTest {
Activation[] afns = new Activation[]{ Activation[] afns = new Activation[]{
Activation.TANH, Activation.TANH,
Activation.SIGMOID, Activation.SIGMOID,
Activation.ELU, Activation.IDENTITY, Activation.SOFTPLUS, Activation.SOFTSIGN, Activation.ELU,
Activation.IDENTITY,
Activation.SOFTPLUS,
Activation.SOFTSIGN,
Activation.HARDTANH, Activation.HARDTANH,
Activation.CUBE, //https://github.com/deeplearning4j/nd4j/issues/2426 Activation.CUBE,
Activation.RELU //JVM crash Activation.RELU
}; };
for (Activation a : afns) { for (Activation a : afns) {
@ -337,12 +345,13 @@ public class TestSameDiffDense extends BaseDL4JTest {
int nIn = 4; int nIn = 4;
int nOut = 3; int nOut = 3;
boolean workspaces = true;
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(wsm)
.updater(new Adam(0.1)) .updater(new Adam(0.1))
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build()) .layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
@ -396,13 +405,14 @@ public class TestSameDiffDense extends BaseDL4JTest {
INDArray outMb = netStandard.output(newIn); INDArray outMb = netStandard.output(newIn);
assertEquals(outMb, outMbsd); assertEquals(outMb, outMbsd);
} }
}
@Test @Test
public void gradientCheck() { public void gradientCheck() {
int nIn = 4; int nIn = 4;
int nOut = 4; int nOut = 4;
for (boolean workspaces : new boolean[]{false, true}) { for (boolean workspaces : new boolean[]{true, false}) {
for (Activation a : new Activation[]{Activation.TANH, Activation.IDENTITY}) { for (Activation a : new Activation[]{Activation.TANH, Activation.IDENTITY}) {
String msg = "workspaces: " + workspaces + ", " + a; String msg = "workspaces: " + workspaces + ", " + a;

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.ScaleVertex; import org.deeplearning4j.nn.conf.graph.ScaleVertex;
import org.deeplearning4j.nn.conf.graph.ShiftVertex; import org.deeplearning4j.nn.conf.graph.ShiftVertex;
@ -52,8 +53,14 @@ public class TestSameDiffLambda extends BaseDL4JTest {
@Test @Test
public void testSameDiffLamdaLayerBasic(){ public void testSameDiffLamdaLayerBasic(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
log.info("--- Workspace Mode: {} ---", wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.seed(12345) .seed(12345)
.updater(new Adam(0.01)) .updater(new Adam(0.01))
.graphBuilder() .graphBuilder()
@ -67,6 +74,8 @@ public class TestSameDiffLambda extends BaseDL4JTest {
//Equavalent, not using SameDiff Lambda: //Equavalent, not using SameDiff Lambda:
ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.seed(12345) .seed(12345)
.updater(new Adam(0.01)) .updater(new Adam(0.01))
.graphBuilder() .graphBuilder()
@ -122,11 +131,17 @@ public class TestSameDiffLambda extends BaseDL4JTest {
INDArray outMb = std.output(newIn)[0]; INDArray outMb = std.output(newIn)[0];
assertEquals(outMb, outMbsd); assertEquals(outMb, outMbsd);
} }
}
@Test @Test
public void testSameDiffLamdaVertexBasic(){ public void testSameDiffLamdaVertexBasic(){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
log.info("--- Workspace Mode: {} ---", wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
.seed(12345) .seed(12345)
.updater(new Adam(0.01)) .updater(new Adam(0.01))
@ -142,6 +157,8 @@ public class TestSameDiffLambda extends BaseDL4JTest {
//Equavalent, not using SameDiff Lambda: //Equavalent, not using SameDiff Lambda:
ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wsm)
.inferenceWorkspaceMode(wsm)
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
.seed(12345) .seed(12345)
.updater(new Adam(0.01)) .updater(new Adam(0.01))
@ -201,3 +218,4 @@ public class TestSameDiffLambda extends BaseDL4JTest {
assertEquals(outMb, outMbsd); assertEquals(outMb, outMbsd);
} }
} }
}

View File

@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static junit.framework.TestCase.*;
import static org.junit.Assume.assumeTrue; import static org.junit.Assume.assumeTrue;
public class ValidateMKLDNN extends BaseDL4JTest { public class ValidateMKLDNN extends BaseDL4JTest {
@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
.padding(0, 0) .padding(0, 0)
.nOut(3) .nOut(3)
.build()) .build())
.layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build()) .layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH) .layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
.kernelSize(kernel) .kernelSize(kernel)
.stride(stride) .stride(stride)
@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest {
} }
} }
} }
@Test
public void compareBatchNormBackward() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
System.out.println(bn.getHelper());
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
INDArray dldin_helper = pcudnn.getSecond();
assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5));
}
} }

View File

@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@ -340,7 +341,8 @@ public class BackPropMLPTest extends BaseDL4JTest {
public static float[] asFloat(INDArray arr) { public static float[] asFloat(INDArray arr) {
long len = arr.length(); long len = arr.length();
// FIXME: int cast if (len > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
float[] f = new float[(int) len]; float[] f = new float[(int) len];
NdIndexIterator iterator = new NdIndexIterator('c', arr.shape()); NdIndexIterator iterator = new NdIndexIterator('c', arr.shape());
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {

View File

@ -320,7 +320,6 @@ public class MultiLayerTest extends BaseDL4JTest {
public static float[] asFloat(INDArray arr) { public static float[] asFloat(INDArray arr) {
long len = arr.length(); long len = arr.length();
// FIXME: int cast
float[] f = new float[(int) len]; float[] f = new float[(int) len];
for (int i = 0; i < len; i++) for (int i = 0; i < len; i++)
f[i] = arr.getFloat(i); f[i] = arr.getFloat(i);

View File

@ -331,7 +331,6 @@ public class TestUpdaters extends BaseDL4JTest {
double calculatedByHandMScalar = 0.2; double calculatedByHandMScalar = 0.2;
double[] expectedM = Nd4j.ones(1, numParams).mul(calculatedByHandMScalar).data().asDouble(); double[] expectedM = Nd4j.ones(1, numParams).mul(calculatedByHandMScalar).data().asDouble();
// FIXME: int cast
double[] actualM = Arrays.copyOfRange(nadamUpdater.getM().data().asDouble(), 0, (int) numParams); double[] actualM = Arrays.copyOfRange(nadamUpdater.getM().data().asDouble(), 0, (int) numParams);
for (int i = 0; i < actualM.length; i++) { for (int i = 0; i < actualM.length; i++) {
actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; actualM[i] = Math.round(actualM[i] * 1e2) / 1e2;

View File

@ -48,6 +48,7 @@ import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
@ -664,7 +665,9 @@ public class TestOptimizers extends BaseDL4JTest {
double xlm1 = parameters.getDouble(nDims - 2); double xlm1 = parameters.getDouble(nDims - 2);
double gl = 200 * (xl - xlm1 * xlm1); double gl = 200 * (xl - xlm1 * xlm1);
// FIXME: int cast if (nDims - 1 > Integer.MAX_VALUE) {
throw new ND4JArraySizeException();
}
gradient.put(0, (int)nDims - 1, gl); gradient.put(0, (int)nDims - 1, gl);
Gradient g = new DefaultGradient(); Gradient g = new DefaultGradient();
g.gradientForVariable().put("W", gradient); g.gradientForVariable().put("W", gradient);
@ -865,8 +868,7 @@ public class TestOptimizers extends BaseDL4JTest {
@Override @Override
public long numParams() { public long numParams() {
// FIXME: int cast return parameters.length();
return (int) parameters.length();
} }
@Override @Override

View File

@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3); SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3);
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10); SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10);
SDVariable b0 = sd.zero("b0", 1, 10); SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10));
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3); SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3);
SDVariable b1 = sd.zero("b1", 1, 3); SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3));
SDVariable z0 = in.mmul(w0).add(b0); SDVariable z0 = in.mmul(w0).add(b0);
SDVariable a0 = sd.nn().tanh(z0); SDVariable a0 = sd.nn().tanh(z0);
@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
Map<String,INDArray> placeholders = new HashMap<>(); Map<String,INDArray> placeholders = new HashMap<>();
placeholders.put("input", f); placeholders.put("input", f);
placeholders.put("label", l); placeholders.put("label", l);
sd.exec(placeholders, lossMse.getVarName()); Map<String,INDArray> map = sd.output(placeholders, lossMse.name(), a1.name());
INDArray outSd = a1.getArr(); INDArray outSd = map.get(a1.name());
INDArray outDl4j = net.output(f); INDArray outDl4j = net.output(f);
assertEquals(testName, outDl4j, outSd); assertEquals(testName, outDl4j, outSd);
@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
//Check score //Check score
double scoreDl4j = net.score(); double scoreDl4j = net.score();
double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore(); double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
assertEquals(testName, scoreDl4j, scoreSd, 1e-6); assertEquals(testName, scoreDl4j, scoreSd, 1e-6);
double lossRegScoreSD = sd.calcRegularizationScore(); double lossRegScoreSD = sd.calcRegularizationScore();
@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
//Check gradients (before updater applied) //Check gradients (before updater applied)
Map<String,INDArray> grads = net.gradient().gradientForVariable(); Map<String,INDArray> grads = net.gradient().gradientForVariable();
sd.execBackwards(placeholders); Map<String,INDArray> gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name());
//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
//We can check correctness though with training param checks later //We can check correctness though with training param checks later
if(l1Val == 0 && l2Val == 0 && wdVal == 0) { if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr()); assertEquals(testName, grads.get("1_b"), gm.get(b1.name()));
assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr()); assertEquals(testName, grads.get("1_W"), gm.get(w1.name()));
assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr()); assertEquals(testName, grads.get("0_b"), gm.get(b0.name()));
assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr()); assertEquals(testName, grads.get("0_W"), gm.get(w0.name()));
} }

View File

@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
} }
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
this.eps = eps; this.eps = eps;
val miniBatch = (int) input.size(0); val miniBatch = (int) input.size(0);
@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3])); dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0], checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1)); (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma, CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
Pointer varCacheData = allocator.getPointer(varCache, context); Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha, checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
dBetaData, eps, meanCacheData, varCacheData)); dBetaData, eps, meanCacheData, varCacheData));
@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override @Override
public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
this.eps = eps; this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.HALF); final boolean isHalf = (x.dataType() == DataType.HALF);
@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3])); dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0], checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1)); (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance(); Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = CudaContext context =

View File

@ -286,7 +286,7 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S
for (INDArray w : exampleData) { for (INDArray w : exampleData) {
val n = w.size(0); val n = w.size(0);
// FIXME: int cast if (Math.min(minExamples, n) < Integer.MAX_VALUE)
minExamples = (int) Math.min(minExamples, n); minExamples = (int) Math.min(minExamples, n);
} }
} }

View File

@ -366,7 +366,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
DataSet ds = mdsToDataSet(mds); DataSet ds = mdsToDataSet(mds);
if (totalOutcomes == -1) { if (totalOutcomes == -1) {
// FIXME: int cast
inputColumns = (int) ds.getFeatures().size(1); inputColumns = (int) ds.getFeatures().size(1);
totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1); totalOutcomes = ds.getLabels() == null ? -1 : (int) ds.getLabels().size(1);
} }
@ -394,7 +393,6 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
stored = next(); stored = next();
useStored = true; useStored = true;
// FIXME: int cast
inputColumns = (int) stored.getFeatures().size(1); inputColumns = (int) stored.getFeatures().size(1);
totalOutcomes = (int) stored.getLabels().size(1); totalOutcomes = (int) stored.getLabels().size(1);
} }

View File

@ -172,7 +172,6 @@ public abstract class AbstractDataSetIterator<T> implements DataSetIterator {
Pair<T, T> pair = iterator.next(); Pair<T, T> pair = iterator.next();
if (numFeatures < 1) { if (numFeatures < 1) {
if (pair.getFirst() instanceof INDArray) { if (pair.getFirst() instanceof INDArray) {
// FIXME: int cast
numFeatures = (int) ((INDArray) pair.getFirst()).length(); numFeatures = (int) ((INDArray) pair.getFirst()).length();
numLabels = (int) ((INDArray) pair.getSecond()).length(); numLabels = (int) ((INDArray) pair.getSecond()).length();
} else if (pair.getFirst() instanceof float[]) { } else if (pair.getFirst() instanceof float[]) {

View File

@ -95,7 +95,6 @@ public class IteratorDataSetIterator implements DataSetIterator {
//Set columns etc for later use //Set columns etc for later use
DataSet temp = list.get(0); DataSet temp = list.get(0);
// FIXME: int cast
inputColumns = (int) temp.getFeatures().size(1); inputColumns = (int) temp.getFeatures().size(1);
totalOutcomes = temp.getLabels() == null ? 0 : (int) temp.getLabels().size(1); //May be null for layerwise pretraining totalOutcomes = temp.getLabels() == null ? 0 : (int) temp.getLabels().size(1); //May be null for layerwise pretraining
} }

View File

@ -73,8 +73,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator {
next = iterator.next(); next = iterator.next();
} }
// FIXME: int cast long nExamples = next.getFeatures(0).size(0);
int nExamples = (int) next.getFeatures(0).size(0);
if (countSoFar + nExamples <= batchSize) { if (countSoFar + nExamples <= batchSize) {
//Add the entire MultiDataSet as-is //Add the entire MultiDataSet as-is
list.add(next); list.add(next);
@ -140,7 +139,7 @@ public class IteratorMultiDataSetIterator implements MultiDataSetIterator {
return out; return out;
} }
private static INDArray getRange(INDArray arr, int exampleFrom, int exampleToExclusive) { private static INDArray getRange(INDArray arr, long exampleFrom, long exampleToExclusive) {
if (arr == null) if (arr == null)
return null; return null;

View File

@ -134,7 +134,7 @@ public abstract class BaseFileIterator<T, P> implements Iterator<T> {
List<T> remainder = new ArrayList<>(); List<T> remainder = new ArrayList<>();
int soFar = 0; int soFar = 0;
for (T t : toMerge) { for (T t : toMerge) {
int size = sizeOf(t); long size = sizeOf(t);
if (soFar + size <= batchSize) { if (soFar + size <= batchSize) {
correctNum.add(t); correctNum.add(t);
@ -190,7 +190,7 @@ public abstract class BaseFileIterator<T, P> implements Iterator<T> {
protected abstract T load(File f); protected abstract T load(File f);
protected abstract int sizeOf(T of); protected abstract long sizeOf(T of);
protected abstract List<T> split(T toSplit); protected abstract List<T> split(T toSplit);

View File

@ -151,7 +151,7 @@ public class FileDataSetIterator extends BaseFileIterator<DataSet, DataSetPrePro
} }
@Override @Override
protected int sizeOf(DataSet of) { protected long sizeOf(DataSet of) {
return of.numExamples(); return of.numExamples();
} }

View File

@ -151,9 +151,8 @@ public class FileMultiDataSetIterator extends BaseFileIterator<MultiDataSet, Mul
} }
@Override @Override
protected int sizeOf(MultiDataSet of) { protected long sizeOf(MultiDataSet of) {
// FIXME: int cast return of.getFeatures(0).size(0);
return (int) of.getFeatures(0).size(0);
} }
@Override @Override

View File

@ -665,8 +665,7 @@ public class BarnesHutTsne implements Model {
if (useAdaGrad) { if (useAdaGrad) {
if (adaGrad == null) { if (adaGrad == null) {
// FIXME: int cast adaGrad = new AdaGrad(gradient.shape(), learningRate);
adaGrad = new AdaGrad(ArrayUtil.toInts(gradient.shape()), learningRate);
adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1, gradChange.length()), adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1, gradChange.length()),
gradChange.shape(), gradient.ordering(), true); gradChange.shape(), gradient.ordering(), true);
} }

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -51,6 +52,13 @@ public class KerasReshape extends KerasLayer {
this(layerConfig, true); this(layerConfig, true);
} }
private long[] listToLongArray(List<Integer> list) {
long[] retVal = new long[list.size()];
for (int i = 0; i < list.size(); ++i) {
retVal[i] = list.get(i);
}
return retVal;
}
/** /**
* Constructor from parsed Keras layer configuration dictionary. * Constructor from parsed Keras layer configuration dictionary.
* *
@ -67,9 +75,7 @@ public class KerasReshape extends KerasLayer {
if (innerConfig.containsKey(targetShape)) { if (innerConfig.containsKey(targetShape)) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape); List<Integer> targetShapeList = (List<Integer>) innerConfig.get(targetShape);
this.targetShape = listToLongArray(targetShapeList);
// FIXME: int cast
this.targetShape = ArrayUtil.toLongArray(ArrayUtil.toArray(targetShapeList));
} }
} }

View File

@ -690,13 +690,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray testLabels = Nd4j.create(predictionsDl4j.shape()); INDArray testLabels = Nd4j.create(predictionsDl4j.shape());
if (testLabels.rank() == 2) { if (testLabels.rank() == 2) {
for (int i = 0; i < testLabels.size(0); i++) { for (int i = 0; i < testLabels.size(0); i++) {
// FIXME: int cast
testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), 1.0); testLabels.putScalar(i, r.nextInt((int) testLabels.size(1)), 1.0);
} }
} else if (testLabels.rank() == 3) { } else if (testLabels.rank() == 3) {
for (int i = 0; i < testLabels.size(0); i++) { for (int i = 0; i < testLabels.size(0); i++) {
for (int j = 0; j < testLabels.size(1); j++) { for (int j = 0; j < testLabels.size(1); j++) {
// FIXME: int cast
testLabels.putScalar(i, j, r.nextInt((int) testLabels.size(1)), 1.0); testLabels.putScalar(i, j, r.nextInt((int) testLabels.size(1)), 1.0);
} }
} }

View File

@ -18,6 +18,9 @@ package org.deeplearning4j.clustering.kdtree;
import lombok.val; import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
@ -28,79 +31,103 @@ import java.util.List;
*/ */
public class HyperRect implements Serializable { public class HyperRect implements Serializable {
private List<Interval> points; //private List<Interval> points;
private float[] lowerEnds;
private float[] higherEnds;
private INDArray lowerEndsIND;
private INDArray higherEndsIND;
public HyperRect(List<Interval> points) { public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
//this.points = points; this.lowerEnds = new float[lowerEndsIn.length];
this.points = new ArrayList<>(points.size()); this.higherEnds = new float[lowerEndsIn.length];
for (int i = 0; i < points.size(); ++i) { System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher); System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
this.points.add(newInterval); lowerEndsIND = Nd4j.createFromArray(lowerEnds);
higherEndsIND = Nd4j.createFromArray(higherEnds);
} }
public HyperRect(float[] point) {
this(point, point);
}
public HyperRect(Pair<float[], float[]> ends) {
this(ends.getFirst(), ends.getSecond());
} }
public void enlargeTo(INDArray point) { public void enlargeTo(INDArray point) {
for (int i = 0; i < points.size(); i++) float[] pointAsArray = point.toFloatVector();
points.get(i).enlarge(point.getDouble(i)); for (int i = 0; i < lowerEnds.length; i++) {
float p = pointAsArray[i];
if (lowerEnds[i] > p)
lowerEnds[i] = p;
else if (higherEnds[i] < p)
higherEnds[i] = p;
}
} }
public static Pair<float[],float[]> point(INDArray vector) {
public static List<Interval> point(INDArray vector) { Pair<float[],float[]> ret = new Pair<>();
List<Interval> ret = new ArrayList<>(); float[] curr = new float[(int)vector.length()];
for (int i = 0; i < vector.length(); i++) { for (int i = 0; i < vector.length(); i++) {
double curr = vector.getDouble(i); curr[i] = vector.getFloat(i);
ret.add(new Interval(curr, curr));
} }
ret.setFirst(curr);
ret.setSecond(curr);
return ret; return ret;
} }
public List<Boolean> contains(INDArray hPoint) { /*public List<Boolean> contains(INDArray hPoint) {
List<Boolean> ret = new ArrayList<>(); List<Boolean> ret = new ArrayList<>();
for (int i = 0; i < hPoint.length(); i++)
ret.add(points.get(i).contains(hPoint.getDouble(i)));
return ret;
}
public double minDistance(INDArray hPoint) {
double ret = 0.0;
for (int i = 0; i < hPoint.length(); i++) { for (int i = 0; i < hPoint.length(); i++) {
double p = hPoint.getDouble(i); ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
Interval interval = points.get(i); higherEnds[i] >= hPoint.getDouble(i));
if (!interval.contains(p)) {
if (p < interval.lower)
ret += Math.pow((p - interval.lower), 2);
else
ret += Math.pow((p - interval.higher), 2);
} }
}
ret = Math.pow(ret, 0.5);
return ret; return ret;
}*/
public double minDistance(INDArray hPoint, INDArray output) {
Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
return output.getFloat(0);
/*double ret = 0.0;
double[] pointAsArray = hPoint.toDoubleVector();
for (int i = 0; i < pointAsArray.length; i++) {
double p = pointAsArray[i];
if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
if (p < lowerEnds[i])
ret += Math.pow((p - lowerEnds[i]), 2);
else
ret += Math.pow((p - higherEnds[i]), 2);
}
}
ret = Math.pow(ret, 0.5);
return ret;*/
} }
public HyperRect getUpper(INDArray hPoint, int desc) { public HyperRect getUpper(INDArray hPoint, int desc) {
Interval interval = points.get(desc); //Interval interval = points.get(desc);
double d = hPoint.getDouble(desc); float higher = higherEnds[desc];
if (interval.higher < d) float d = hPoint.getFloat(desc);
if (higher < d)
return null; return null;
HyperRect ret = new HyperRect(new ArrayList<>(points)); HyperRect ret = new HyperRect(lowerEnds,higherEnds);
Interval i2 = ret.points.get(desc); if (ret.lowerEnds[desc] < d)
if (i2.lower < d) ret.lowerEnds[desc] = d;
i2.lower = d;
return ret; return ret;
} }
public HyperRect getLower(INDArray hPoint, int desc) { public HyperRect getLower(INDArray hPoint, int desc) {
Interval interval = points.get(desc); //Interval interval = points.get(desc);
double d = hPoint.getDouble(desc); float lower = lowerEnds[desc];
if (interval.lower > d) float d = hPoint.getFloat(desc);
if (lower > d)
return null; return null;
HyperRect ret = new HyperRect(new ArrayList<>(points)); HyperRect ret = new HyperRect(lowerEnds,higherEnds);
Interval i2 = ret.points.get(desc); //Interval i2 = ret.points.get(desc);
if (i2.higher > d) if (ret.higherEnds[desc] > d)
i2.higher = d; ret.higherEnds[desc] = d;
return ret; return ret;
} }
@ -108,33 +135,10 @@ public class HyperRect implements Serializable {
public String toString() { public String toString() {
String retVal = ""; String retVal = "";
retVal += "["; retVal += "[";
for (val point : points) { for (int i = 0; i < lowerEnds.length; ++i) {
retVal += "(" + point.lower + " - " + point.higher + ") "; retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
} }
retVal += "]"; retVal += "]";
return retVal; return retVal;
} }
public static class Interval {
private double lower, higher;
public Interval(double lower, double higher) {
this.lower = lower;
this.higher = higher;
}
public boolean contains(double point) {
return lower <= point || point <= higher;
}
public void enlarge(double p) {
if (lower > p)
lower = p;
else if (higher < p)
higher = p;
}
}
} }

View File

@ -56,7 +56,7 @@ public class KDTree implements Serializable {
if (root == null) { if (root == null) {
root = new KDNode(point); root = new KDNode(point);
rect = new HyperRect(HyperRect.point(point)); rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
} else { } else {
int disc = 0; int disc = 0;
KDNode node = root; KDNode node = root;
@ -125,15 +125,21 @@ public class KDTree implements Serializable {
return node.getPoint(); return node.getPoint();
} }
// Share this data for recursive calls of "knn"
private float currentDistance;
private INDArray currentPoint;
private INDArray minDistance = Nd4j.scalar(0.f);
public List<Pair<Double, INDArray>> knn(INDArray point, double distance) { public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
List<Pair<Double, INDArray>> best = new ArrayList<>(); List<Pair<Float, INDArray>> best = new ArrayList<>();
knn(root, point, rect, distance, best, 0); currentDistance = distance;
Collections.sort(best, new Comparator<Pair<Double, INDArray>>() { currentPoint = point;
knn(root, rect, best, 0);
Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
@Override @Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) { public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
return Double.compare(o1.getKey(), o2.getKey()); return Float.compare(o1.getKey(), o2.getKey());
} }
}); });
@ -141,22 +147,21 @@ public class KDTree implements Serializable {
} }
private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best, private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
int _disc) { if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
if (node == null || rect == null || rect.minDistance(point) > dist)
return; return;
int _discNext = (_disc + 1) % dims; int _discNext = (_disc + 1) % dims;
double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult() float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
.doubleValue(); .floatValue();
if (distance <= dist) { if (distance <= currentDistance) {
best.add(Pair.of(distance, node.getPoint())); best.add(Pair.of(distance, node.getPoint()));
} }
HyperRect lower = rect.getLower(node.point, _disc); HyperRect lower = rect.getLower(node.point, _disc);
HyperRect upper = rect.getUpper(node.point, _disc); HyperRect upper = rect.getUpper(node.point, _disc);
knn(node.getLeft(), point, lower, dist, best, _discNext); knn(node.getLeft(), lower, best, _discNext);
knn(node.getRight(), point, upper, dist, best, _discNext); knn(node.getRight(), upper, best, _discNext);
} }
/** /**
@ -171,7 +176,7 @@ public class KDTree implements Serializable {
private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
int _disc) { int _disc) {
if (node == null || rect.minDistance(point) > dist) if (node == null || rect.minDistance(point, minDistance) > dist)
return Pair.of(Double.POSITIVE_INFINITY, null); return Pair.of(Double.POSITIVE_INFINITY, null);
int _discNext = (_disc + 1) % dims; int _discNext = (_disc + 1) % dims;

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.clustering.kdtree; package org.deeplearning4j.clustering.kdtree;
import org.joda.time.Instant;
import org.nd4j.shade.guava.base.Stopwatch;
import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val; import lombok.val;
import org.deeplearning4j.clustering.BaseDL4JTest; import org.deeplearning4j.clustering.BaseDL4JTest;
@ -28,6 +30,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.guava.primitives.Floats;
import org.opencv.ml.KNearest; import org.opencv.ml.KNearest;
import java.util.ArrayList; import java.util.ArrayList;
@ -35,6 +38,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -53,17 +58,17 @@ public class KDTreeTest extends BaseDL4JTest {
@Before @Before
public void setUp() { public void setUp() {
kdTree = new KDTree(2); kdTree = new KDTree(2);
double[] data = new double[]{7,2}; float[] data = new float[]{7,2};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{5,4}; data = new float[]{5,4};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{2,3}; data = new float[]{2,3};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{4,7}; data = new float[]{4,7};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{9,6}; data = new float[]{9,6};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{8,1}; data = new float[]{8,1};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
} }
@ -168,26 +173,30 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void testKNN() { public void testKNN() {
int n = 10; int dimensions = 512;
// make a KD-tree of dimension {#n} int vectorsNo = 50000;
KDTree kdTree = new KDTree(n); // make a KD-tree of dimension {#dimensions}
for (int i = -1; i < n; i++) { Stopwatch stopwatch = Stopwatch.createStarted();
KDTree kdTree = new KDTree(dimensions);
for (int i = -1; i < vectorsNo; i++) {
// Insert a unit vector along each dimension // Insert a unit vector along each dimension
List<Double> vec = new ArrayList<>(n); INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions);
// i = -1 ensures the origin is in the Tree
for (int k = 0; k < n; k++) {
vec.add((k == i) ? 1.0 : 0.0);
}
INDArray indVec = Nd4j.create(Nd4j.createBuffer(Doubles.toArray(vec)));
kdTree.insert(indVec); kdTree.insert(indVec);
} }
stopwatch.stop();
System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS));
Random rand = new Random(); Random rand = new Random();
// random point in the Hypercube // random point in the Hypercube
List<Double> pt = new ArrayList(n); List<Double> pt = new ArrayList(dimensions);
for (int k = 0; k < n; k++) { for (int k = 0; k < dimensions; k++) {
pt.add(rand.nextDouble() * 10.0); pt.add(rand.nextFloat() * 10.0);
} }
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); stopwatch.reset();
stopwatch.start();
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
stopwatch.stop();
System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS));
} }
@Test @Test
@ -195,15 +204,15 @@ public class KDTreeTest extends BaseDL4JTest {
int n = 2; int n = 2;
KDTree kdTree = new KDTree(n); KDTree kdTree = new KDTree(n);
double[] data = new double[]{3,3}; float[] data = new float[]{3,3};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{1,1}; data = new float[]{1,1};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{2,2}; data = new float[]{2,2};
kdTree.insert(Nd4j.createFromArray(data)); kdTree.insert(Nd4j.createFromArray(data));
data = new double[]{0,0}; data = new float[]{0,0};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);
assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
@ -220,88 +229,88 @@ public class KDTreeTest extends BaseDL4JTest {
assertEquals(6, kdTree.size()); assertEquals(6, kdTree.size());
double[] data = new double[]{8,1}; float[] data = new float[]{8,1};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(2.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_2() { public void testKNN_2() {
double[] data = new double[]{8, 1}; float[] data = new float[]{8, 1};
List<Pair<Double, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(8.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_3() { public void testKNN_3() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_4() { public void testKNN_4() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void testKNN_5() { public void testKNN_5() {
double[] data = new double[]{2, 3}; float[] data = new float[]{2, 3};
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(2.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getDouble(0), 1e-5); assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getDouble(1), 1e-5); assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getDouble(0), 1e-5); assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getDouble(1), 1e-5); assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
} }
@Test @Test
public void test_KNN_6() { public void test_KNN_6() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 10.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -318,8 +327,8 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void test_KNN_7() { public void test_KNN_7() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 5.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -334,8 +343,8 @@ public class KDTreeTest extends BaseDL4JTest {
@Test @Test
public void test_KNN_8() { public void test_KNN_8() {
double[] data = new double[]{4, 6}; float[] data = new float[]{4, 6};
val result = kdTree.knn(Nd4j.createFromArray(data), 20.0); List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
@ -392,12 +401,12 @@ public class KDTreeTest extends BaseDL4JTest {
Duration duration = new Duration(start, end); Duration duration = new Duration(start, end);
System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());
List<Double> pt = new ArrayList(num); List<Float> pt = new ArrayList(num);
for (int k = 0; k < n; k++) { for (int k = 0; k < n; k++) {
pt.add((double)(num / 2)); pt.add((float)(num / 2));
} }
start = System.currentTimeMillis(); start = System.currentTimeMillis();
List<Pair<Double, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0); List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
end = System.currentTimeMillis(); end = System.currentTimeMillis();
duration = new Duration(start, end); duration = new Duration(start, end);
long elapsed = end - start; long elapsed = end - start;

View File

@ -50,6 +50,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.*; import java.util.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest {
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money")); assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
} }
@Test
public void testWordsNearestSum() throws IOException {
log.info("Load & Vectorize Sentences....");
SentenceIterator iter = new BasicLineIterator(inputFile);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(5)
.iterations(1)
.layerSize(100)
.seed(42)
.windowSize(5)
.iterate(iter)
.tokenizerFactory(t)
.build();
log.info("Fitting Word2Vec model....");
vec.fit();
log.info("Writing word vectors to text file....");
log.info("Closest Words:");
Collection<String> lst = vec.wordsNearestSum("day", 10);
log.info("10 Words closest to 'day': {}", lst);
assertTrue(lst.contains("week"));
assertTrue(lst.contains("night"));
assertTrue(lst.contains("year"));
assertTrue(lst.contains("years"));
assertTrue(lst.contains("time"));
}
private static void printWords(String target, Collection<String> list, Word2Vec vec) { private static void printWords(String target, Collection<String> list, Word2Vec vec) {
System.out.println("Words close to [" + target + "]:"); System.out.println("Words close to [" + target + "]:");
for (String word : list) { for (String word : list) {

View File

@ -104,7 +104,7 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo
} }
protected void initAdaGrad() { protected void initAdaGrad() {
int[] shape = new int[] {vocab.numWords() + 1, vectorLength}; long[] shape = new long[] {vocab.numWords() + 1, vectorLength};
int length = ArrayUtil.prod(shape); int length = ArrayUtil.prod(shape);
adaGrad = new AdaGrad(shape, lr.get()); adaGrad = new AdaGrad(shape, lr.get());
adaGrad.setStateViewArray(Nd4j.zeros(shape).reshape(1, length), shape, Nd4j.order(), true); adaGrad.setStateViewArray(Nd4j.zeros(shape).reshape(1, length), shape, Nd4j.order(), true);
@ -124,8 +124,7 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo
if (adaGrad == null) if (adaGrad == null)
initAdaGrad(); initAdaGrad();
// FIXME: int cast return adaGrad.getGradient(gradient, column, syn0.shape());
return adaGrad.getGradient(gradient, column, ArrayUtil.toInts(syn0.shape()));
} }
@Override @Override
@ -370,7 +369,6 @@ public class InMemoryLookupTable<T extends SequenceElement> implements WeightLoo
else { else {
nextRandom.set(nextRandom.get() * 25214903917L + 11); nextRandom.set(nextRandom.get() * 25214903917L + 11);
// FIXME: int cast
int idx = (int) Math.abs((int) (nextRandom.get() >> 16) % table.length()); int idx = (int) Math.abs((int) (nextRandom.get() >> 16) % table.length());
target = table.getInt(idx); target = table.getInt(idx);

View File

@ -33,7 +33,6 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.Aggregate; import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW;
import org.nd4j.linalg.api.ops.impl.nlp.CbowRound; import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray; import org.nd4j.linalg.util.DeviceLocalNDArray;

View File

@ -104,11 +104,10 @@ public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorit
weightAdaGrad = new AdaGrad(new int[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate); weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate);
bias = Nd4j.create(syn0.rows()); bias = Nd4j.create(syn0.rows());
// FIXME: int cast biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate);
biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), this.learningRate);
// maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8); // maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8);
@ -237,15 +236,13 @@ public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorit
private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) { private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors //gradient for word vectors
INDArray grad1 = contextVector.mul(gradient); INDArray grad1 = contextVector.mul(gradient);
// FIXME: int cast INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape());
INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), ArrayUtil.toInts(syn0.shape()));
//update vector //update vector
wordVector.subi(update); wordVector.subi(update);
double w1Bias = bias.getDouble(element1.getIndex()); double w1Bias = bias.getDouble(element1.getIndex());
// FIXME: int cast double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape());
double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), ArrayUtil.toInts(bias.shape()));
double update2 = w1Bias - biasGradient; double update2 = w1Bias - biasGradient;
bias.putScalar(element1.getIndex(), update2); bias.putScalar(element1.getIndex(), update2);
} }

View File

@ -351,13 +351,13 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
if (lookupTable instanceof InMemoryLookupTable) { if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
INDArray syn0 = l.getSyn0(); INDArray syn0 = l.getSyn0();
INDArray weights = syn0.norm2(0).rdivi(1).muli(words); INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
INDArray weights = temp.muli(words);
INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray distances = syn0.mulRowVector(weights).sum(1);
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
INDArray sort = sorted[0]; INDArray sort = sorted[0];
List<String> ret = new ArrayList<>(); List<String> ret = new ArrayList<>();
// FIXME: int cast
if (top > sort.length()) if (top > sort.length())
top = (int) sort.length(); top = (int) sort.length();
//there will be a redundant word //there will be a redundant word

View File

@ -72,7 +72,7 @@ public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryL
putVector(Word2Vec.DEFAULT_UNK, randUnk); putVector(Word2Vec.DEFAULT_UNK, randUnk);
} }
if (weightAdaGrad == null || reset) { if (weightAdaGrad == null || reset) {
weightAdaGrad = new AdaGrad(new int[] {vocab.numWords() + 1, vectorLength}, lr.get()); weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get());
} }
@ -81,7 +81,7 @@ public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryL
bias = Nd4j.create(syn0.rows()); bias = Nd4j.create(syn0.rows());
if (biasAdaGrad == null || reset) { if (biasAdaGrad == null || reset) {
biasAdaGrad = new AdaGrad(ArrayUtil.toInts(bias.shape()), lr.get()); biasAdaGrad = new AdaGrad(bias.shape(), lr.get());
} }
@ -140,13 +140,13 @@ public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryL
private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) { private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors //gradient for word vectors
INDArray grad1 = contextVector.mul(gradient); INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape())); INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
//update vector //update vector
wordVector.subi(update); wordVector.subi(update);
double w1Bias = bias.getDouble(w1.getIndex()); double w1Bias = bias.getDouble(w1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape())); double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
double update2 = w1Bias - biasGradient; double update2 = w1Bias - biasGradient;
bias.putScalar(w1.getIndex(), update2); bias.putScalar(w1.getIndex(), update2);
} }

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.function.Consumer; import org.nd4j.linalg.function.Consumer;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -293,7 +294,8 @@ public class GradientCheckUtil {
ss = n; ss = n;
} }
// FIXME: int cast if (ss > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
stepSizeForParam.put(paramNames.get(i), (int) ss); stepSizeForParam.put(paramNames.get(i), (int) ss);
} }
} }

View File

@ -140,10 +140,9 @@ public class ElementWiseVertex extends GraphVertex {
//CNN inputs... also check that the channels, width and heights match: //CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
// FIXME: int cast val fd = firstConv.getChannels();
val fd = (int) firstConv.getChannels(); val fw = firstConv.getWidth();
val fw = (int) firstConv.getWidth(); val fh = firstConv.getHeight();
val fh = (int) firstConv.getHeight();
for (int i = 1; i < vertexInputs.length; i++) { for (int i = 1; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != InputType.Type.CNN) { if (vertexInputs[i].getType() != InputType.Type.CNN) {
@ -155,10 +154,9 @@ public class ElementWiseVertex extends GraphVertex {
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
// FIXME: int cast val od = otherConv.getChannels();
val od = (int) otherConv.getChannels(); val ow = otherConv.getWidth();
val ow = (int) otherConv.getWidth(); val oh = otherConv.getHeight();
val oh = (int) otherConv.getHeight();
if (fd != od || fw != ow || fh != oh) { if (fd != od || fw != ow || fh != oh) {
throw new InvalidInputTypeException( throw new InvalidInputTypeException(

View File

@ -94,13 +94,12 @@ public class MergeVertex extends GraphVertex {
// CNN3D inputs: check that the channels, width and height match: // CNN3D inputs: check that the channels, width and height match:
InputType.InputTypeConvolutional3D firstConv = (InputType.InputTypeConvolutional3D) first; InputType.InputTypeConvolutional3D firstConv = (InputType.InputTypeConvolutional3D) first;
// FIXME: int cast val fd = firstConv.getDepth();
val fd = (int) firstConv.getDepth(); val fw = firstConv.getWidth();
val fw = (int) firstConv.getWidth(); val fh = firstConv.getHeight();
val fh = (int) firstConv.getHeight(); val fc = firstConv.getChannels();
val fc = (int) firstConv.getChannels();
int depthSum = fc; long depthSum = fc;
InputType.InputTypeConvolutional3D otherConv = null; InputType.InputTypeConvolutional3D otherConv = null;
for (int i = 1; i < vertexInputs.length; i++) { for (int i = 1; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != InputType.Type.CNN3D) { if (vertexInputs[i].getType() != InputType.Type.CNN3D) {
@ -109,10 +108,10 @@ public class MergeVertex extends GraphVertex {
} }
otherConv = (InputType.InputTypeConvolutional3D) vertexInputs[i]; otherConv = (InputType.InputTypeConvolutional3D) vertexInputs[i];
val od = (int) otherConv.getDepth(); val od = otherConv.getDepth();
val ow = (int) otherConv.getWidth(); val ow = otherConv.getWidth();
val oh = (int) otherConv.getHeight(); val oh = otherConv.getHeight();
val oc = (int) otherConv.getChannels(); val oc = otherConv.getChannels();
if (fd != od || fw != ow || fh != oh) { if (fd != od || fw != ow || fh != oh) {
throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN3D activations of different width/heights:" + "first [channels,width,height] = [" + fd + "," + fw + "," + fh throw new InvalidInputTypeException("Invalid input: MergeVertex cannot merge CNN3D activations of different width/heights:" + "first [channels,width,height] = [" + fd + "," + fw + "," + fh
@ -177,12 +176,11 @@ public class MergeVertex extends GraphVertex {
//CNN inputs... also check that the channels, width and heights match: //CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
// FIXME: int cast val fd = firstConv.getChannels();
val fd = (int) firstConv.getChannels(); val fw = firstConv.getWidth();
val fw = (int) firstConv.getWidth(); val fh = firstConv.getHeight();
val fh = (int) firstConv.getHeight();
int depthSum = fd; long depthSum = fd;
for (int i = 1; i < vertexInputs.length; i++) { for (int i = 1; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != InputType.Type.CNN) { if (vertexInputs[i].getType() != InputType.Type.CNN) {
@ -194,10 +192,9 @@ public class MergeVertex extends GraphVertex {
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
// FIXME: int cast val od = otherConv.getChannels();
val od = (int) otherConv.getChannels(); val ow = otherConv.getWidth();
val ow = (int) otherConv.getWidth(); val oh = otherConv.getHeight();
val oh = (int) otherConv.getHeight();
if (fw != ow || fh != oh) { if (fw != ow || fh != oh) {
throw new InvalidInputTypeException( throw new InvalidInputTypeException(

View File

@ -131,12 +131,11 @@ public class PoolHelperVertex extends GraphVertex {
//CNN inputs... also check that the channels, width and heights match: //CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
// FIXME: int cast val fd = firstConv.getChannels();
val fd = (int) firstConv.getChannels(); val fw = firstConv.getWidth();
val fw = (int) firstConv.getWidth(); val fh = firstConv.getHeight();
val fh = (int) firstConv.getHeight();
int depthSum = fd; long depthSum = fd;
for (int i = 1; i < vertexInputs.length; i++) { for (int i = 1; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != InputType.Type.CNN) { if (vertexInputs[i].getType() != InputType.Type.CNN) {
@ -148,10 +147,9 @@ public class PoolHelperVertex extends GraphVertex {
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
// FIXME: int cast long od = otherConv.getChannels();
int od = (int) otherConv.getChannels(); long ow = otherConv.getWidth();
int ow = (int) otherConv.getWidth(); long oh = otherConv.getHeight();
int oh = (int) otherConv.getHeight();
if (fw != ow || fh != oh) { if (fw != ow || fh != oh) {
throw new InvalidInputTypeException( throw new InvalidInputTypeException(

View File

@ -150,12 +150,11 @@ public class UnstackVertex extends GraphVertex {
//CNN inputs... also check that the channels, width and heights match: //CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
// FIXME: int cast val fd = firstConv.getChannels();
val fd = (int) firstConv.getChannels(); val fw = firstConv.getWidth();
val fw = (int) firstConv.getWidth(); val fh = firstConv.getHeight();
val fh = (int) firstConv.getHeight();
int depthSum = fd; long depthSum = fd;
for (int i = 1; i < vertexInputs.length; i++) { for (int i = 1; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != InputType.Type.CNN) { if (vertexInputs[i].getType() != InputType.Type.CNN) {
@ -167,10 +166,9 @@ public class UnstackVertex extends GraphVertex {
InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i]; InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
// FIXME: int cast val od = otherConv.getChannels();
val od = (int) otherConv.getChannels(); val ow = otherConv.getWidth();
val ow = (int) otherConv.getWidth(); val oh = otherConv.getHeight();
val oh = (int) otherConv.getHeight();
if (fw != ow || fh != oh) { if (fw != ow || fh != oh) {
throw new InvalidInputTypeException( throw new InvalidInputTypeException(

View File

@ -402,18 +402,17 @@ public abstract class InputType implements Serializable {
//Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something //Note: ConvolutionalFlat and FeedForward look identical... but either should work OK if using something
// like FeedForwardToCnnPreProcessor // like FeedForwardToCnnPreProcessor
// FIXME: int cast
switch (inputArray.rank()) { switch (inputArray.rank()) {
case 2: case 2:
return InputType.feedForward((int) inputArray.size(1)); return InputType.feedForward(inputArray.size(1));
case 3: case 3:
return InputType.recurrent((int) inputArray.size(1), (int) inputArray.size(2)); return InputType.recurrent(inputArray.size(1), (int) inputArray.size(2));
case 4: case 4:
//Order: [minibatch, channels, height, width] -> [h, w, c] //Order: [minibatch, channels, height, width] -> [h, w, c]
return InputType.convolutional((int) inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1)); return InputType.convolutional(inputArray.size(2), (int) inputArray.size(3), (int) inputArray.size(1));
case 5: case 5:
//Order: [minibatch, channels, depth, height, width] -> [d, h, w, c] //Order: [minibatch, channels, depth, height, width] -> [d, h, w, c]
return InputType.convolutional3D((int) inputArray.size(2), (int) inputArray.size(3), return InputType.convolutional3D(inputArray.size(2), (int) inputArray.size(3),
(int) inputArray.size(4), (int) inputArray.size(1)); (int) inputArray.size(4), (int) inputArray.size(1));
default: default:
throw new IllegalArgumentException( throw new IllegalArgumentException(

View File

@ -152,17 +152,18 @@ public class Cnn3DLossLayer extends FeedForwardLayer {
} }
@Override @Override
public void setNIn(int nIn){ public void setNIn(long nIn){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Cnn3DLossLayer has no parameters, thus nIn will always equal nOut."); "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut.");
} }
@Override @Override
public void setNOut(int nOut){ public void setNOut(long nOut){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Cnn3DLossLayer has no parameters, thus nIn will always equal nOut."); "Cnn3DLossLayer has no parameters, thus nIn will always equal nOut.");
} }
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public Cnn3DLossLayer build() { public Cnn3DLossLayer build() {

View File

@ -145,13 +145,13 @@ public class CnnLossLayer extends FeedForwardLayer {
} }
@Override @Override
public void setNIn(int nIn){ public void setNIn(long nIn){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"This layer has no parameters, thus nIn will always equal nOut."); "This layer has no parameters, thus nIn will always equal nOut.");
} }
@Override @Override
public void setNOut(int nOut){ public void setNOut(long nOut){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"This layer has no parameters, thus nIn will always equal nOut."); "This layer has no parameters, thus nIn will always equal nOut.");
} }

View File

@ -88,7 +88,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
//Probably: user did InputType.recurrent(x) without specifying sequence length //Probably: user did InputType.recurrent(x) without specifying sequence length
outLength = -1; outLength = -1;
} else { } else {
outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0], outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]); convolutionMode, dilation[0]);
} }
return InputType.recurrent(nOut, outLength); return InputType.recurrent(nOut, outLength);

View File

@ -117,14 +117,14 @@ public abstract class FeedForwardLayer extends BaseLayer {
* this is the input channels, otherwise is the previous layer size. * this is the input channels, otherwise is the previous layer size.
* *
*/ */
protected int nIn = 0; protected long nIn = 0;
/** /**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers, * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size. * this is the input channels, otherwise is the previous layer size.
* *
*/ */
protected int nOut = 0; protected long nOut = 0;
/** /**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers, * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
@ -144,8 +144,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
* @param nIn Number of inputs for the layer * @param nIn Number of inputs for the layer
*/ */
public T nIn(long nIn) { public T nIn(long nIn) {
// FIXME: int cast this.setNIn(nIn);
this.setNIn((int) nIn);
return (T) this; return (T) this;
} }

View File

@ -41,12 +41,9 @@ public class InputTypeUtil {
Class<?> layerClass) { Class<?> layerClass) {
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
// FIXME: int cast val hIn = i.getHeight();
val hIn = (int) i.getHeight(); val wIn = i.getWidth();
val wIn = (int) i.getWidth();
val inHeight = (int) i.getHeight();
val inWidth = (int) i.getWidth();
int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same
int padW = (padding == null ? 0 : padding[1]); int padW = (padding == null ? 0 : padding[1]);
int kH = kernelSize[0]; int kH = kernelSize[0];
@ -69,13 +66,13 @@ public class InputTypeUtil {
} }
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
int hOut = stride[0] * hIn; long hOut = stride[0] * hIn;
int wOut = stride[1] * wIn; long wOut = stride[1] * wIn;
return InputType.convolutional(hOut, wOut, outputDepth); return InputType.convolutional(hOut, wOut, outputDepth);
} }
int hOut = sH * (hIn - 1) + kH - 2 * padH; long hOut = sH * (hIn - 1) + kH - 2 * padH;
int wOut = sW * (wIn - 1) + kW - 2 * padW; long wOut = sW * (wIn - 1) + kW - 2 * padW;
return InputType.convolutional(hOut, wOut, outputDepth); return InputType.convolutional(hOut, wOut, outputDepth);
} }
@ -91,10 +88,9 @@ public class InputTypeUtil {
InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType; InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType;
// FIXME: int cast long inDepth = i.getDepth();
val inDepth = (int) i.getDepth(); long inHeight = i.getHeight();
val inHeight = (int) i.getHeight(); long inWidth = i.getWidth();
val inWidth = (int) i.getWidth();
int padD = (padding == null ? 0 : padding[0]); int padD = (padding == null ? 0 : padding[0]);
int padH = (padding == null ? 0 : padding[1]); int padH = (padding == null ? 0 : padding[1]);
@ -211,9 +207,9 @@ public class InputTypeUtil {
return InputType.convolutional3D(outD, outH, outW, outputChannels); return InputType.convolutional3D(outD, outH, outW, outputChannels);
} }
int dOut = (inDepth - kD + 2 * padD) / sD + 1; long dOut = (inDepth - kD + 2 * padD) / sD + 1;
int hOut = (inHeight - kH + 2 * padH) / sH + 1; long hOut = (inHeight - kH + 2 * padH) / sH + 1;
int wOut = (inWidth - kW + 2 * padW) / sW + 1; long wOut = (inWidth - kW + 2 * padW) / sW + 1;
return InputType.convolutional3D(dOut, hOut, wOut, outputChannels); return InputType.convolutional3D(dOut, hOut, wOut, outputChannels);
} }
@ -296,9 +292,8 @@ public class InputTypeUtil {
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
// FIXME: int cast long inHeight = i.getHeight();
val inHeight = (int) i.getHeight(); long inWidth = i.getWidth();
val inWidth = (int) i.getWidth();
int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same int padH = (padding == null ? 0 : padding[0]); //May be null for ConvolutionMode.Same
int padW = (padding == null ? 0 : padding[1]); int padW = (padding == null ? 0 : padding[1]);
int kH = kernelSize[0]; int kH = kernelSize[0];
@ -379,8 +374,8 @@ public class InputTypeUtil {
return InputType.convolutional(outH, outW, outputDepth); return InputType.convolutional(outH, outW, outputDepth);
} }
int hOut = (inHeight - kH + 2 * padH) / sH + 1; long hOut = (inHeight - kH + 2 * padH) / sH + 1;
int wOut = (inWidth - kW + 2 * padW) / sW + 1; long wOut = (inWidth - kW + 2 * padW) / sW + 1;
return InputType.convolutional(hOut, wOut, outputDepth); return InputType.convolutional(hOut, wOut, outputDepth);
} }

View File

@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer {
val weightsShape = new long[] {outputSize, featureDim, nOut}; val weightsShape = new long[] {outputSize, featureDim, nOut};
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
if (hasBias) { if (hasBias) {
val biasShape = new long[] {1, nOut}; val biasShape = new long[] {nOut};
params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
} }
} }
@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer {
if (hasBias) { if (hasBias) {
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b); SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
return activation.asSameDiff("out", sameDiff, biasAddedResult); return activation.asSameDiff("out", sameDiff, biasAddedResult);
} else { } else {
return activation.asSameDiff("out", sameDiff, result); return activation.asSameDiff("out", sameDiff, result);

View File

@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer {
val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut}; val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut};
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
if (hasBias) { if (hasBias) {
val biasShape = new long[] {1, nOut}; val biasShape = new long[] {nOut};
params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape); params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
} }
} }
@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer {
if (hasBias) { if (hasBias) {
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b); SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true);
return activation.asSameDiff("out", sameDiff, biasAddedResult); return activation.asSameDiff("out", sameDiff, biasAddedResult);
} else { } else {
return activation.asSameDiff("out", sameDiff, permutedResult); return activation.asSameDiff("out", sameDiff, permutedResult);

View File

@ -142,13 +142,13 @@ public class RnnLossLayer extends FeedForwardLayer {
} }
@Override @Override
public void setNIn(int nIn){ public void setNIn(long nIn){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"This layer has no parameters, thus nIn will always equal nOut."); "This layer has no parameters, thus nIn will always equal nOut.");
} }
@Override @Override
public void setNOut(int nOut){ public void setNOut(long nOut){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"This layer has no parameters, thus nIn will always equal nOut."); "This layer has no parameters, thus nIn will always equal nOut.");
} }

View File

@ -82,12 +82,12 @@ public class Subsampling1DLayer extends SubsamplingLayer {
} }
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
long inputTsLength = r.getTimeSeriesLength(); long inputTsLength = r.getTimeSeriesLength();
int outLength; long outLength;
if (inputTsLength < 0) { if (inputTsLength < 0) {
//Probably: user did InputType.recurrent(x) without specifying sequence length //Probably: user did InputType.recurrent(x) without specifying sequence length
outLength = -1; outLength = -1;
} else { } else {
outLength = Convolution1DUtils.getOutputSize((int) inputTsLength, kernelSize[0], stride[0], padding[0], outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]); convolutionMode, dilation[0]);
} }
return InputType.recurrent(r.getSize(), outLength); return InputType.recurrent(r.getSize(), outLength);

View File

@ -32,6 +32,7 @@ import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.Collection; import java.util.Collection;
@ -138,9 +139,11 @@ public class Subsampling3DLayer extends NoParamLayer {
+ "\"): Expected CNN input, got " + inputType); + "\"): Expected CNN input, got " + inputType);
} }
// FIXME: int cast long inChannels = ((InputType.InputTypeConvolutional3D) inputType).getChannels();
if (inChannels > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation return InputTypeUtil.getOutputTypeCnn3DLayers(inputType, kernelSize, stride, padding, new int[] {1, 1, 1}, // no dilation
convolutionMode, (int) ((InputType.InputTypeConvolutional3D) inputType).getChannels(), convolutionMode, (int) inChannels,
layerIndex, getLayerName(), Subsampling3DLayer.class); layerIndex, getLayerName(), Subsampling3DLayer.class);
} }

View File

@ -83,11 +83,10 @@ public class Upsampling3D extends BaseUpsamplingLayer {
} }
InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType; InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType;
// FIXME: int cast long inHeight = (int) i.getHeight();
int inHeight = (int) i.getHeight(); long inWidth = (int) i.getWidth();
int inWidth = (int) i.getWidth(); long inDepth = (int) i.getDepth();
int inDepth = (int) i.getDepth(); long inChannels = (int) i.getChannels();
int inChannels = (int) i.getChannels();
return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels); return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels);
} }

View File

@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
defineVertex(temp, tempInputs); defineVertex(temp, tempInputs);
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
for (Integer i : tempInputs.map.keySet()) { for (Integer i : tempInputs.map.keySet()) {
list.add(tempInputs.map.get(i).getVarName()); list.add(tempInputs.map.get(i).name());
} }
params.defineInputs(list.toArray(new String[list.size()])); params.defineInputs(list.toArray(new String[list.size()]));
} }

View File

@ -259,7 +259,7 @@ public class OCNNOutputLayer extends BaseOutputLayer {
} }
@Override @Override
public void setNOut(int nOut){ public void setNOut(long nOut){
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Unable to specify number of outputs with ocnn. Outputs are fixed to 1."); "Unable to specify number of outputs with ocnn. Outputs are fixed to 1.");
} }

View File

@ -79,6 +79,7 @@ import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat; import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment; import org.nd4j.linalg.heartbeat.reports.Environment;
@ -3329,7 +3330,6 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//In 99+% of cases, the input and labels dimension 0 size should be identical //In 99+% of cases, the input and labels dimension 0 size should be identical
//The only real exceptions: space to batch, and batch to space layers //The only real exceptions: space to batch, and batch to space layers
//In those cases, we should base it on the labels size, as this impacts gradient calculation //In those cases, we should base it on the labels size, as this impacts gradient calculation
// FIXME: int cast
return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0); return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
} }
@ -3653,7 +3653,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
if (endTimeIdx > timeSeriesLength) if (endTimeIdx > timeSeriesLength)
endTimeIdx = timeSeriesLength; endTimeIdx = timeSeriesLength;
// FIXME: int cast if (startTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
List<INDArray[]> list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks); List<INDArray[]> list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);
setInputs(list.get(0)); setInputs(list.get(0));
@ -3799,7 +3800,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
} }
// FIXME: int cast if (minibatchSize > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
Pair<INDArray, MaskState> outPair = Pair<INDArray, MaskState> outPair =
current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize); current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize);
map.put(topologicalOrder[i], outPair); map.put(topologicalOrder[i], outPair);
@ -4664,7 +4666,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
* @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
* @return Size of the layer * @return Size of the layer
*/ */
public int layerSize(int layer) { public long layerSize(int layer) {
if (layer < 0 || layer > layers.length) { if (layer < 0 || layer > layers.length) {
throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
+ (layers.length - 1) + " inclusive"); + (layers.length - 1) + " inclusive");
@ -4683,7 +4685,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
* @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
* @return Size of the layer * @return Size of the layer
*/ */
public int layerInputSize(int layer) { public long layerInputSize(int layer) {
if (layer < 0 || layer > layers.length) { if (layer < 0 || layer > layers.length) {
throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
+ (layers.length - 1) + " inclusive"); + (layers.length - 1) + " inclusive");
@ -4701,7 +4703,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
* @param layerName Name of the layer to get the size of * @param layerName Name of the layer to get the size of
* @return Size of the layer * @return Size of the layer
*/ */
public int layerSize(String layerName) { public long layerSize(String layerName) {
Layer l = getLayer(layerName); Layer l = getLayer(layerName);
if(l == null){ if(l == null){
throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
@ -4712,8 +4714,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
FeedForwardLayer ffl = (FeedForwardLayer) conf; FeedForwardLayer ffl = (FeedForwardLayer) conf;
// FIXME: int cast return ffl.getNOut();
return (int) ffl.getNOut();
} }
/** /**
@ -4727,7 +4728,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
* @param layerName Name of the layer to get the size of * @param layerName Name of the layer to get the size of
* @return Size of the layer * @return Size of the layer
*/ */
public int layerInputSize(String layerName) { public long layerInputSize(String layerName) {
Layer l = getLayer(layerName); Layer l = getLayer(layerName);
if(l == null){ if(l == null){
throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
@ -4738,8 +4739,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
FeedForwardLayer ffl = (FeedForwardLayer) conf; FeedForwardLayer ffl = (FeedForwardLayer) conf;
// FIXME: int cast return ffl.getNIn();
return (int) ffl.getNIn();
} }
/** /**

View File

@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex {
} }
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
return Nd4j.hstack(in); return Nd4j.concat(1, in);
} }
} }

View File

@ -43,10 +43,10 @@ import java.util.Arrays;
* @author Justin Long (crockpotveggies) * @author Justin Long (crockpotveggies)
*/ */
public class UnstackVertex extends BaseGraphVertex { public class UnstackVertex extends BaseGraphVertex {
private int from; private long from;
private int stackSize; private int stackSize;
private long forwardShape[]; private long forwardShape[];
private int step; private long step;
public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) { public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) {
this(graph, name, vertexIndex, null, null, from, stackSize, dataType); this(graph, name, vertexIndex, null, null, from, stackSize, dataType);
@ -77,10 +77,9 @@ public class UnstackVertex extends BaseGraphVertex {
// once we know the inputs, save the shape and interval size for doBackward // once we know the inputs, save the shape and interval size for doBackward
this.forwardShape = Arrays.copyOf(inputs[0].shape(), inputs[0].rank()); this.forwardShape = Arrays.copyOf(inputs[0].shape(), inputs[0].rank());
// FIXME: int cast this.step = inputs[0].size(0) / stackSize;
this.step = (int) inputs[0].size(0) / stackSize; long start = from * step;
int start = from * step; long end = (from + 1) * step;
int end = (from + 1) * step;
INDArray ret; INDArray ret;
switch (inputs[0].rank()) { //TODO remove the dups here if/when possible (gradient checks must pass) switch (inputs[0].rank()) { //TODO remove the dups here if/when possible (gradient checks must pass)
@ -108,8 +107,8 @@ public class UnstackVertex extends BaseGraphVertex {
throw new IllegalStateException("Cannot do backward pass: error not set"); throw new IllegalStateException("Cannot do backward pass: error not set");
INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inputs[0].dataType(), forwardShape); INDArray out = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, inputs[0].dataType(), forwardShape);
int start = from * step; long start = from * step;
int end = (from + 1) * step; long end = (from + 1) * step;
switch (forwardShape.length) { switch (forwardShape.length) {
case 2: case 2:
@ -154,8 +153,8 @@ public class UnstackVertex extends BaseGraphVertex {
} }
//Mask arrays are either 1d (column vector) or 2d... //Mask arrays are either 1d (column vector) or 2d...
int start = from * minibatchSize; long start = from * minibatchSize;
int end = (from + 1) * minibatchSize; long end = (from + 1) * minibatchSize;
INDArray outMask = maskArrays[0].get(NDArrayIndex.interval(start, end), NDArrayIndex.all()); INDArray outMask = maskArrays[0].get(NDArrayIndex.interval(start, end), NDArrayIndex.all());
return new Pair<>(outMask, currentMaskState); return new Pair<>(outMask, currentMaskState);
} }

View File

@ -87,9 +87,8 @@ public class LastTimeStepVertex extends BaseGraphVertex {
INDArray out; INDArray out;
if (mask == null) { if (mask == null) {
// FIXME: int cast
//No mask array -> extract same (last) column for all //No mask array -> extract same (last) column for all
int lastTS = (int) inputs[0].size(2) - 1; long lastTS = inputs[0].size(2) - 1;
out = inputs[0].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS)); out = inputs[0].get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS));
out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out); out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
fwdPassTimeSteps = null; //Null -> last time step for all examples fwdPassTimeSteps = null; //Null -> last time step for all examples
@ -99,8 +98,7 @@ public class LastTimeStepVertex extends BaseGraphVertex {
//Want the index of the last non-zero entry in the mask array. //Want the index of the last non-zero entry in the mask array.
//Check a little here by using mulRowVector([0,1,2,3,...]) and argmax //Check a little here by using mulRowVector([0,1,2,3,...]) and argmax
// FIXME: int cast long maxTsLength = fwdPassShape[2];
int maxTsLength = (int) fwdPassShape[2];
INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, mask.dataType()); INDArray row = Nd4j.linspace(0, maxTsLength - 1, maxTsLength, mask.dataType());
INDArray temp = mask.mulRowVector(row); INDArray temp = mask.mulRowVector(row);
INDArray lastElementIdx = Nd4j.argMax(temp, 1); INDArray lastElementIdx = Nd4j.argMax(temp, 1);

View File

@ -346,7 +346,6 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
@Override @Override
public int getInputMiniBatchSize() { public int getInputMiniBatchSize() {
// FIXME: int cast
return (int) input.size(0); return (int) input.size(0);
} }

View File

@ -229,7 +229,6 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
*/ */
@Override @Override
public int numLabels() { public int numLabels() {
// FIXME: int cast
return (int) labels.size(1); return (int) labels.size(1);
} }

View File

@ -236,7 +236,6 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
*/ */
@Override @Override
public int numLabels() { public int numLabels() {
// FIXME: int cast
return (int) labels.size(1); return (int) labels.size(1);
} }

View File

@ -86,19 +86,18 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped); INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d); delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d);
// FIXME: int cast long n = input.size(0);
int n = (int)input.size(0); long d, h, w, c;
int d, h, w, c;
if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){ if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
d = (int)input.size(1); d = input.size(1);
h = (int)input.size(2); h = input.size(2);
w = (int)input.size(3); w = input.size(3);
c = (int)input.size(4); c = input.size(4);
} else { } else {
d = (int)input.size(2); d = input.size(2);
h = (int)input.size(3); h = input.size(3);
w = (int)input.size(4); w = input.size(4);
c = (int)input.size(1); c = input.size(1);
} }
INDArray delta5d = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), delta2d, n, d, h, w, c, workspaceMgr, ArrayType.ACTIVATION_GRAD); INDArray delta5d = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), delta2d, n, d, h, w, c, workspaceMgr, ArrayType.ACTIVATION_GRAD);
@ -130,7 +129,6 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
@Override @Override
public int numLabels() { public int numLabels() {
// FIXME: int cast
return (int) labels.size(1); return (int) labels.size(1);
} }
@ -180,10 +178,8 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
INDArray input2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), in, workspaceMgr, ArrayType.ACTIVATIONS); INDArray input2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), in, workspaceMgr, ArrayType.ACTIVATIONS);
INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training); INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training);
// FIXME: int cast long n = input.size(0);
long d, h, w, c;
int n = (int)input.size(0);
int d, h, w, c;
if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){ if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
d = (int)input.size(1); d = (int)input.size(1);
h = (int)input.size(2); h = (int)input.size(2);
@ -262,19 +258,18 @@ public class Cnn3DLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
val newShape = input.shape().clone(); val newShape = input.shape().clone();
newShape[1] = 1; newShape[1] = 1;
// FIXME long n = input.size(0);
int n = (int)input.size(0); long d, h, w, c;
int d, h, w, c;
if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){ if(layerConf().getDataFormat() == Convolution3D.DataFormat.NDHWC){
d = (int)input.size(1); d = input.size(1);
h = (int)input.size(2); h = input.size(2);
w = (int)input.size(3); w = input.size(3);
c = (int)input.size(4); c = input.size(4);
} else { } else {
d = (int)input.size(2); d = input.size(2);
h = (int)input.size(3); h = input.size(3);
w = (int)input.size(4); w = input.size(4);
c = (int)input.size(1); c = input.size(1);
} }
INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), scoreArray, n, d, h, w, c, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo5d(layerConf().getDataFormat(), scoreArray, n, d, h, w, c, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray summedScores = scoreArrayTs.sum(1,2,3,4); INDArray summedScores = scoreArrayTs.sum(1,2,3,4);

View File

@ -88,8 +88,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped); INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d); delta2d = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta2d);
// FIXME: int cast INDArray delta4d = ConvolutionUtils.reshape2dTo4d(delta2d, input.shape(), workspaceMgr, ArrayType.ACTIVATION_GRAD);
INDArray delta4d = ConvolutionUtils.reshape2dTo4d(delta2d, ArrayUtil.toInts(input.shape()), workspaceMgr, ArrayType.ACTIVATION_GRAD);
// grab the empty gradient // grab the empty gradient
Gradient gradient = new DefaultGradient(); Gradient gradient = new DefaultGradient();
@ -119,7 +118,6 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
@Override @Override
public int numLabels() { public int numLabels() {
// FIXME: int cast
return (int) labels.size(1); return (int) labels.size(1);
} }
@ -169,8 +167,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
INDArray input2d = ConvolutionUtils.reshape4dTo2d(in, workspaceMgr, ArrayType.ACTIVATIONS); INDArray input2d = ConvolutionUtils.reshape4dTo2d(in, workspaceMgr, ArrayType.ACTIVATIONS);
INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training); INDArray out2d = layerConf().getActivationFn().getActivation(input2d, training);
// FIXME: int cast return ConvolutionUtils.reshape2dTo4d(out2d, input.shape(), workspaceMgr, ArrayType.ACTIVATIONS);
return ConvolutionUtils.reshape2dTo4d(out2d, ArrayUtil.toInts(input.shape()), workspaceMgr, ArrayType.ACTIVATIONS);
} }
@Override @Override
@ -236,8 +233,7 @@ public class CnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Cn
val newShape = input.shape().clone(); val newShape = input.shape().clone();
newShape[1] = 1; newShape[1] = 1;
// FIXME INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo4d(scoreArray, newShape, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray scoreArrayTs = ConvolutionUtils.reshape2dTo4d(scoreArray, ArrayUtil.toInts(newShape), workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray summedScores = scoreArrayTs.sum(1,2,3).reshape(scoreArrayTs.size(0), 1); INDArray summedScores = scoreArrayTs.sum(1,2,3).reshape(scoreArrayTs.size(0), 1);
if (fullNetRegTerm != 0.0) { if (fullNetRegTerm != 0.0) {

View File

@ -71,8 +71,7 @@ public class Convolution3DLayer extends ConvolutionLayer {
boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW; boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW;
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0);
int inD = (int) (isNCDHW ? input.size(2) : input.size(1)); int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
int inW = (int) (isNCDHW ? input.size(4) : input.size(3)); int inW = (int) (isNCDHW ? input.size(4) : input.size(3));
@ -189,8 +188,7 @@ public class Convolution3DLayer extends ConvolutionLayer {
+ " " + layerId()); + " " + layerId());
} }
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0);
int inputChannels = (int) (isNCDHW ? input.size(1) : input.size(4)); int inputChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
int inD =(int) (isNCDHW ? input.size(2) : input.size(1)); int inD =(int) (isNCDHW ? input.size(2) : input.size(1));
int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inH = (int) (isNCDHW ? input.size(3) : input.size(2));

View File

@ -35,6 +35,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -113,13 +114,12 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
if(epsilon.dataType() != dataType) if(epsilon.dataType() != dataType)
epsilon = epsilon.castTo(dataType); epsilon = epsilon.castTo(dataType);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0);
int inH = (int) input.size(2); int inH = (int) input.size(2);
int inW = (int) input.size(3); int inW = (int) input.size(3);
int outDepth = (int) weights.size(0); long outDepth = weights.size(0);
int inDepth = (int) weights.size(1); long inDepth = weights.size(1);
int kH = (int) weights.size(2); int kH = (int) weights.size(2);
int kW = (int) weights.size(3); int kW = (int) weights.size(3);
@ -143,7 +143,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray biasGradView = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY); INDArray biasGradView = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY);
INDArray weightGradView = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); //4d, c order. Shape: [outDepth,inDepth,kH,kW] INDArray weightGradView = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); //4d, c order. Shape: [outDepth,inDepth,kH,kW]
INDArray weightGradView2df = Shape INDArray weightGradView2df = Shape
.newShapeNoCopy(weightGradView, new int[] {outDepth, inDepth * kH * kW}, false).transpose(); .newShapeNoCopy(weightGradView, new long[]{outDepth, inDepth * kH * kW}, false).transpose();
@ -204,7 +204,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
//Note: due to the permute in preOut, and the fact that we essentially do a preOut.muli(epsilon), this reshape //Note: due to the permute in preOut, and the fact that we essentially do a preOut.muli(epsilon), this reshape
// should be zero-copy; only possible exception being sometimes with the "identity" activation case // should be zero-copy; only possible exception being sometimes with the "identity" activation case
INDArray delta2d = delta.reshape('c', new int[] {outDepth, miniBatch * outH * outW}); //Shape.newShapeNoCopy(delta,new int[]{outDepth,miniBatch*outH*outW},false); INDArray delta2d = delta.reshape('c', new long[] {outDepth, miniBatch * outH * outW}); //Shape.newShapeNoCopy(delta,new int[]{outDepth,miniBatch*outH*outW},false);
//Do im2col, but with order [miniB,outH,outW,depthIn,kH,kW]; but need to input [miniBatch,channels,kH,kW,outH,outW] given the current im2col implementation //Do im2col, but with order [miniB,outH,outW,depthIn,kH,kW]; but need to input [miniBatch,channels,kH,kW,outH,outW] given the current im2col implementation
//To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that //To get this: create an array of the order we want, permute it to the order required by im2col implementation, and then do im2col on that
@ -231,7 +231,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
//Calculate epsilons for layer below, in 2d format (note: this is in 'image patch' format before col2im reduction) //Calculate epsilons for layer below, in 2d format (note: this is in 'image patch' format before col2im reduction)
//Note: cc -> f mmul here, then reshape to 6d in f order //Note: cc -> f mmul here, then reshape to 6d in f order
INDArray epsNext2d = w2d.mmul(delta2d); //TODO can we reuse im2col array instead of allocating new result array? INDArray epsNext2d = w2d.mmul(delta2d); //TODO can we reuse im2col array instead of allocating new result array?
INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new int[] {kW, kH, inDepth, outW, outH, miniBatch}, true); INDArray eps6d = Shape.newShapeNoCopy(epsNext2d, new long[] {kW, kH, inDepth, outW, outH, miniBatch}, true);
//Calculate epsilonNext by doing im2col reduction. //Calculate epsilonNext by doing im2col reduction.
//Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW] //Current col2im implementation expects input with order: [miniBatch,channels,kH,kW,outH,outW]
@ -282,7 +282,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
} }
} }
protected void validateInputDepth(int inDepth) { protected void validateInputDepth(long inDepth) {
if (input.size(1) != inDepth) { if (input.size(1) != inDepth) {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
@ -313,14 +313,13 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long outDepth = weights.size(0);
int outDepth = (int) weights.size(0); long inDepth = weights.size(1);
int inDepth = (int) weights.size(1);
validateInputDepth(inDepth); validateInputDepth(inDepth);
int kH = (int) weights.size(2); long kH = weights.size(2);
int kW = (int) weights.size(3); long kW = weights.size(3);
int[] dilation = layerConf().getDilation(); int[] dilation = layerConf().getDilation();
int[] kernel = layerConf().getKernelSize(); int[] kernel = layerConf().getKernelSize();
@ -331,7 +330,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
// FIXME: int cast if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
strides, dilation ); strides, dilation );
} else { } else {
@ -397,10 +397,12 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c'); INDArray col = Nd4j.createUninitialized(weights.dataType(), new long[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
INDArray col2 = col.permute(0, 3, 4, 5, 1, 2); INDArray col2 = col.permute(0, 3, 4, 5, 1, 2);
INDArray im2ColIn = input.castTo(col2.dataType()); //No op if already (for example) float INDArray im2ColIn = input.castTo(col2.dataType()); //No op if already (for example) float
Convolution.im2col(im2ColIn, kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], if (kH > Integer.MAX_VALUE || kW > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
Convolution.im2col(im2ColIn, (int)kH, (int)kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1],
convolutionMode == ConvolutionMode.Same, col2); convolutionMode == ConvolutionMode.Same, col2);
INDArray im2col2d = Shape.newShapeNoCopy(col, new int[] {miniBatch * outH * outW, inDepth * kH * kW}, false); INDArray im2col2d = Shape.newShapeNoCopy(col, new long[] {miniBatch * outH * outW, inDepth * kH * kW}, false);
//Current order of weights: [depthOut,depthIn,kH,kW], c order //Current order of weights: [depthOut,depthIn,kH,kW], c order
//Permute to give [kW,kH,depthIn,depthOut], f order //Permute to give [kW,kH,depthIn,depthOut], f order
@ -418,7 +420,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
} }
//Now, reshape to [outW,outH,miniBatch,outDepth], and permute to have correct output order: [miniBath,outDepth,outH,outW]; //Now, reshape to [outW,outH,miniBatch,outDepth], and permute to have correct output order: [miniBath,outDepth,outH,outW];
z = Shape.newShapeNoCopy(z, new int[] {outW, outH, miniBatch, outDepth}, true); z = Shape.newShapeNoCopy(z, new long[] {outW, outH, miniBatch, outDepth}, true);
z = z.permute(2, 3, 1, 0); z = z.permute(2, 3, 1, 0);
if (training && cacheMode != CacheMode.NONE && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE) && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE)) { if (training && cacheMode != CacheMode.NONE && workspaceMgr.hasConfiguration(ArrayType.FF_CACHE) && workspaceMgr.isWorkspaceOpen(ArrayType.FF_CACHE)) {

View File

@ -171,13 +171,14 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
+ " " + layerId()); + " " + layerId());
} }
// FIXME: int cast long inDepth = weights.size(0);
int inDepth = (int) weights.size(0); long outDepth = weights.size(1);
int outDepth = (int) weights.size(1);
if (input.size(1) != inDepth && input.size(3) == inDepth) { if (input.size(1) != inDepth && input.size(3) == inDepth) {
//TODO AB 2019/10/25 this is an ugly "pseudo-NHWC support" hack that needs to be removed ASAD
//https://github.com/eclipse/deeplearning4j/issues/8315
input = input.permute(0, 3, 1, 2); input = input.permute(0, 3, 1, 2);
} else if (input.size(1) != inDepth && input.size(3) != inDepth) { } else if (input.size(1) != inDepth ) {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
if (layerName == null) if (layerName == null)
layerName = "(not named)"; layerName = "(not named)";
@ -197,7 +198,6 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
int[] pad; int[] pad;
int[] outSize; int[] outSize;
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
// FIXME: int cast
outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
strides, dilation ); strides, dilation );
@ -206,8 +206,8 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
} }
int outH = outSize[0]; long outH = outSize[0];
int outW = outSize[1]; long outW = outSize[1];
val miniBatch = input.size(0); val miniBatch = input.size(0);

View File

@ -32,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -75,12 +76,11 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
INDArray input = this.input.castTo(dataType); //No-op if correct type INDArray input = this.input.castTo(dataType); //No-op if correct type
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0);
int inH = (int)input.size(2); int inH = (int)input.size(2);
int inW = (int)input.size(3); int inW = (int)input.size(3);
int inDepth = (int) depthWiseWeights.size(2); long inDepth = depthWiseWeights.size(2);
int kH = (int) depthWiseWeights.size(0); int kH = (int) depthWiseWeights.size(0);
int kW = (int) depthWiseWeights.size(1); int kW = (int) depthWiseWeights.size(1);
@ -169,10 +169,9 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
INDArray input = this.input.castTo(dataType); //no-op if correct dtype INDArray input = this.input.castTo(dataType); //no-op if correct dtype
// FIXME: int cast long inDepth = depthWiseWeights.size(2);
int inDepth = (int) depthWiseWeights.size(2); long depthMultiplier = depthWiseWeights.size(3);
int depthMultiplier = (int) depthWiseWeights.size(3); long outDepth = depthMultiplier * inDepth;
int outDepth = depthMultiplier * inDepth;
if (input.size(1) != inDepth) { if (input.size(1) != inDepth) {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
@ -197,7 +196,9 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation);
// FIXME: int cast if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) {
throw new ND4JArraySizeException();
}
pad = ConvolutionUtils.getSameModeTopLeftPadding( pad = ConvolutionUtils.getSameModeTopLeftPadding(
outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation); outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation);
} else { } else {
@ -205,8 +206,8 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation);
} }
int outH = outSize[0]; long outH = outSize[0];
int outW = outSize[1]; long outW = outSize[1];
val miniBatch = input.size(0); val miniBatch = input.size(0);
INDArray output = workspaceMgr.create( INDArray output = workspaceMgr.create(

View File

@ -33,6 +33,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@ -90,8 +91,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0);
int inH = (int)input.size(2); int inH = (int)input.size(2);
int inW = (int)input.size(3); int inW = (int)input.size(3);
@ -194,9 +194,8 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
+ " " + layerId()); + " " + layerId());
} }
// FIXME: int cast long inDepth = depthWiseWeights.size(1);
int inDepth = (int) depthWiseWeights.size(1); long outDepth = pointWiseWeights.size(0);
int outDepth = (int) pointWiseWeights.size(0);
if (input.size(1) != inDepth) { if (input.size(1) != inDepth) {
String layerName = conf.getLayer().getLayerName(); String layerName = conf.getLayer().getLayerName();
@ -220,7 +219,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation
// FIXME: int cast if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) {
throw new ND4JArraySizeException();
}
pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel,
strides, dilation ); strides, dilation );
} else { } else {

View File

@ -75,11 +75,10 @@ public class SpaceToDepth extends AbstractLayer<org.deeplearning4j.nn.conf.layer
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(true); assertInputSet(true);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inDepth = input.size(1);
int inDepth = (int) input.size(1); long inH = input.size(2);
int inH = (int) input.size(2); long inW = input.size(3);
int inW = (int) input.size(3);
INDArray input = this.input.castTo(dataType); //No-op if already correct type INDArray input = this.input.castTo(dataType); //No-op if already correct type
@ -122,17 +121,16 @@ public class SpaceToDepth extends AbstractLayer<org.deeplearning4j.nn.conf.layer
return preOutput; return preOutput;
} }
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long depth = input.size(1);
int depth = (int) input.size(1); long inH = input.size(2);
int inH = (int) input.size(2); long inW = input.size(3);
int inW = (int) input.size(3);
int blockSize = getBlockSize(); int blockSize = getBlockSize();
int outH = inH / blockSize; long outH = inH / blockSize;
int outW = inW / blockSize; long outW = inW / blockSize;
int outDepth = depth * blockSize * blockSize; long outDepth = depth * blockSize * blockSize;
INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{1, miniBatch * outDepth * outH * outW}, 'c'); INDArray out = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), new long[]{1, miniBatch * outDepth * outH * outW}, 'c');
INDArray reshapedOut; INDArray reshapedOut;

View File

@ -71,9 +71,8 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
boolean isNCDHW = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW; boolean isNCDHW = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW;
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inChannels = isNCDHW ? input.size(1) : input.size(4);
int inChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
int inD = (int) (isNCDHW ? input.size(2) : input.size(1)); int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
int inW = (int) (isNCDHW ? input.size(4) : input.size(3)); int inW = (int) (isNCDHW ? input.size(4) : input.size(3));
@ -148,9 +147,8 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
} }
} }
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inChannels = isNCDHW ? input.size(1) : input.size(4);
int inChannels = (int) (isNCDHW ? input.size(1) : input.size(4));
int inD = (int) (isNCDHW ? input.size(2) : input.size(1)); int inD = (int) (isNCDHW ? input.size(2) : input.size(1));
int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inH = (int) (isNCDHW ? input.size(3) : input.size(2));
int inW = (int) (isNCDHW ? input.size(4) : input.size(3)); int inW = (int) (isNCDHW ? input.size(4) : input.size(3));
@ -170,9 +168,9 @@ public class Subsampling3DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
outSize = Convolution3DUtils.get3DOutputSize( outSize = Convolution3DUtils.get3DOutputSize(
input, kernel, strides, pad, convolutionMode, dilation, isNCDHW); input, kernel, strides, pad, convolutionMode, dilation, isNCDHW);
} }
int outD = outSize[0]; long outD = outSize[0];
int outH = outSize[1]; long outH = outSize[1];
int outW = outSize[2]; long outW = outSize[2];
String opName = layerConf().getPoolingType() == PoolingType.MAX ? "maxpool3dnew" : "avgpool3dnew"; String opName = layerConf().getPoolingType() == PoolingType.MAX ? "maxpool3dnew" : "avgpool3dnew";

View File

@ -108,9 +108,6 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
if(epsilon.dataType() != dataType) if(epsilon.dataType() != dataType)
epsilon = epsilon.castTo(dataType); epsilon = epsilon.castTo(dataType);
// FIXME: int cast
int miniBatch = (int) input.size(0);
int inDepth = (int) input.size(1);
int inH = (int)input.size(2); int inH = (int)input.size(2);
int inW = (int)input.size(3); int inW = (int)input.size(3);
@ -158,9 +155,6 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
//subsampling doesn't have weights and thus gradients are not calculated for this layer //subsampling doesn't have weights and thus gradients are not calculated for this layer
//only scale and reshape epsilon //only scale and reshape epsilon
// FIXME: int cast
int inputHeight = (int) input().size(-2);
int inputWidth = (int) input().size(-1);
Gradient retGradient = new DefaultGradient(); Gradient retGradient = new DefaultGradient();
@ -231,9 +225,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inDepth = input.size(1);
int inDepth = (int) input.size(1);
int inH = (int)input.size(2); int inH = (int)input.size(2);
int inW = (int)input.size(3); int inW = (int)input.size(3);
@ -250,8 +243,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
pad = layerConf().getPadding(); pad = layerConf().getPadding();
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
} }
int outH = outSize[0]; long outH = outSize[0];
int outW = outSize[1]; long outW = outSize[1];
if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) { if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) {
@ -278,9 +271,6 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
} }
} }
INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c'); INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c');
DynamicCustomOp.DynamicCustomOpsBuilder b; DynamicCustomOp.DynamicCustomOpsBuilder b;
int extra = 0; int extra = 0;

View File

@ -65,11 +65,10 @@ public class Upsampling1D extends Upsampling2D {
INDArray originalInput = input; INDArray originalInput = input;
input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1); input = input.castTo(dataType).reshape(input.size(0), input.size(1), input.size(2), 1);
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inDepth = input.size(1);
int inDepth = (int) input.size(1); long inH = input.size(2);
int inH = (int) input.size(2); long inW = input.size(3);
int inW = (int) input.size(3);
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), miniBatch * inDepth * inH * inW); INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), miniBatch * inDepth * inH * inW);

View File

@ -62,11 +62,10 @@ public class Upsampling2D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(true); assertInputSet(true);
// FIXME: int cast long miniBatch = (int) input.size(0);
int miniBatch = (int) input.size(0); long inDepth = (int) input.size(1);
int inDepth = (int) input.size(1); long inH = (int) input.size(2);
int inH = (int) input.size(2); long inW = (int) input.size(3);
int inW = (int) input.size(3);
INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c');
@ -106,15 +105,14 @@ public class Upsampling2D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
return preOutput; return preOutput;
} }
// FIXME: int cast long miniBatch = (int) input.size(0);
int miniBatch = (int) input.size(0); long inDepth = (int) input.size(1);
int inDepth = (int) input.size(1); long inH = (int) input.size(2);
int inH = (int) input.size(2); long inW = (int) input.size(3);
int inW = (int) input.size(3);
int[] size = getSize(); int[] size = getSize();
int outH = inH * size[0]; int outH = (int)inH * size[0];
int outW = inW * size[1]; int outW = (int)inW * size[1];
INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c'); INDArray reshapedOutput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, outH, outW}, 'c');

View File

@ -68,22 +68,21 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
assertInputSet(true); assertInputSet(true);
boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW; boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;
// FIXME: int cast
// Assumes NCDHW order // Assumes NCDHW order
int miniBatch = (int) input.size(0); long miniBatch = input.size(0);
int inChannels, inD, inH, inW; long inChannels, inD, inH, inW;
int[] intArgs; int[] intArgs;
if(ncdhw){ if(ncdhw){
inChannels = (int) input.size(1); inChannels = input.size(1);
inD = (int) input.size(2); inD = input.size(2);
inH = (int) input.size(3); inH = input.size(3);
inW = (int) input.size(4); inW = input.size(4);
intArgs = new int[] {1}; // 1 is channels first intArgs = new int[] {1}; // 1 is channels first
} else { } else {
inD = (int) input.size(1); inD = input.size(1);
inH = (int) input.size(2); inH = input.size(2);
inW = (int) input.size(3); inW = input.size(3);
inChannels = (int) input.size(4); inChannels = input.size(4);
intArgs = new int[] {0}; // 0 is channels last intArgs = new int[] {0}; // 0 is channels last
} }
@ -134,9 +133,8 @@ public class Upsampling3D extends AbstractLayer<org.deeplearning4j.nn.conf.layer
} }
boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW; boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW;
// FIXME: int cast long miniBatch = input.size(0);
int miniBatch = (int) input.size(0); long inChannels, inD, inH, inW;
int inChannels, inD, inH, inW;
int[] intArgs; int[] intArgs;
int[] size = getSize(); int[] size = getSize();
if(ncdhw){ if(ncdhw){

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.custom.ScatterUpdate; import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -64,8 +65,7 @@ public class EmbeddingLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
weightGradients.assign(0); weightGradients.assign(0);
// FIXME: int cast long[] indexes = new long[(int) input.length()];
int[] indexes = new int[(int) input.length()];
for (int i = 0; i < indexes.length; i++) { for (int i = 0; i < indexes.length; i++) {
indexes[i] = input.getInt(i, 0); indexes[i] = input.getInt(i, 0);
} }
@ -99,7 +99,8 @@ public class EmbeddingLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.
val nIn = layerConf().getNIn(); val nIn = layerConf().getNIn();
// FIXME: int cast if (input.length() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int[] indexes = new int[(int) input.length()]; int[] indexes = new int[(int) input.length()];
for (int i = 0; i < indexes.length; i++) { for (int i = 0; i < indexes.length; i++) {
indexes[i] = input.getInt(i, 0); indexes[i] = input.getInt(i, 0);

View File

@ -16,21 +16,28 @@
package org.deeplearning4j.nn.layers.mkldnn; package org.deeplearning4j.nn.layers.mkldnn;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
@ -56,33 +63,59 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
} }
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
//2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166 if(input.dataType() != DataType.FLOAT)
//Also no MKL-DNN implemented for backprop anyway return null; //MKL-DNN only supports float
/* /*
INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon}; //TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335
List<INDArray> args = new ArrayList<>();
args.add(input);
args.add(meanCache);
args.add(varCache);
args.add(epsilon);
if(gamma != null)
args.add(gamma.reshape(gamma.length()));
if(beta != null)
args.add(beta.reshape(beta.length()));
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape());
INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, } DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(args.toArray(new INDArray[0]))
BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder() .addIntegerArguments(
.applyBeta(gamma != null) gamma == null ? 0 : 1, //Apply scale
.applyGamma(gamma != null) beta == null ? 0 : 1, //Apply beta
.axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases 1) //Axis (NCHW)
.epsilon(eps) .addFloatingPointArguments(eps)
.inputArrays(in)
.outputArrays(new INDArray[]{out})
.build(); .build();
Nd4j.exec(bn);
*/
INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
op.setOutputArgument(0, epsAtInput);
op.setOutputArgument(1, dLdm);
op.setOutputArgument(2, dLdv);
if(dGammaView != null) {
//Both are always null/not null simultaneously
op.setOutputArgument(3, dGammaView.reshape(dGammaView.length()));
op.setOutputArgument(4, dBetaView.reshape(dBetaView.length()));
}
Nd4j.exec(op);
Gradient g = new DefaultGradient();
g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
return new Pair<>(g, epsAtInput);
*/
return null; return null;
} }
@Override @Override
public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var,
double decay, double eps, LayerWorkspaceMgr workspaceMgr) { double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
if(x.dataType() != DataType.FLOAT) if(x.dataType() != DataType.FLOAT)
return null; //MKL-DNN only supports float return null; //MKL-DNN only supports float

View File

@ -0,0 +1,168 @@
package org.deeplearning4j.nn.layers.mkldnn;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class MKLDNNLSTMHelper implements LSTMHelper {
@Override
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
//TODO check other activation functions for MKLDNN
return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled();
}
@Override
public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT,
int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey,
String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews,
INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
//Not yet implemented/supported
return null;
}
@Override
public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training,
INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards,
String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
/*
DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward)
Inputs:
x = [bS, nIn, sL]
Wx = [nIn, 4*nOut]
Wr = [nOut, 4*nOut]
Wp = [3*nOut] Optional peephole weights
b = [4*nOut]
seqLen = [bS]
initialOut = [bs, nOut]
initialCell = [bs, nOut]
Outputs:
out = [bS, nOut, sL]
outLast = [bs, nOut]
cellLast = [bs,nOut]
Gates order: input, forget, input modulation, output
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
*/
INDArray b1d = biases.reshape(biases.length());
INDArray seqLen = null;
if(maskArray != null){
seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen])
}
List<INDArray> args = new ArrayList<>();
args.add(input);
args.add(inputWeights);
args.add(recurrentWeights);
if(hasPeepholeConnections){
throw new IllegalStateException("Not yet implemented");
}
args.add(b1d);
if(seqLen != null)
args.add(seqLen);
if(prevOutputActivations != null)
args.add(prevOutputActivations);
if(prevMemCellState != null)
args.add(prevMemCellState);
IActivation a = ((LSTM)conf.getLayer()).getActivationFn();
DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
.addInputs(args.toArray(new INDArray[0]))
.addBooleanArguments(
true, //hasBiases
seqLen != null, //hasSeqLen
prevOutputActivations != null, //hasInitH
prevMemCellState != null, //hasInitC
hasPeepholeConnections, //hasPh
true, //retFullSeq
true, //retLastH
true //retLastC
)
.addIntegerArguments(
2, //data format: 2 = [bS, nIn, sL]
0, //direction: 0 = forward
activationToArg(gateActivationFn), //Gate activation
activationToArg(a), //Cell state activation
activationToArg(a) //Output activation (same as cell in DL4J)
)
.build();
List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
for(LongShapeDescriptor lsd : outShapes){
INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
op.addOutputArgument(arr);
}
FwdPassReturn f = new FwdPassReturn();
f.fwdPassOutput = op.getOutputArgument(0);
f.lastAct = op.getOutputArgument(1);
f.lastMemCell = op.getOutputArgument(2);
return f;
}
@Override
public Map<String, Long> helperMemoryUse() {
return Collections.emptyMap();
}
private int activationToArg(IActivation a){
//0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
if(a instanceof ActivationTanH)
return 0;
if(a instanceof ActivationReLU)
return 1;
if(a instanceof ActivationSigmoid)
return 2;
if(a instanceof ActivationIdentity)
return 3;
if(a instanceof ActivationLReLU)
return 4;
if(a instanceof ActivationThresholdedReLU)
return 5;
if(a instanceof ActivationHardSigmoid)
return 7;
if(a instanceof ActivationELU)
return 8;
if(a instanceof ActivationSoftSign)
return 9;
if(a instanceof ActivationSoftPlus)
return 10;
throw new IllegalStateException("Unknown or not supported activation function: " + a);
}
}

View File

@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config
INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD); INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
INDArray gamma = null; INDArray gamma = null;
INDArray beta = null;
INDArray dGammaView; INDArray dGammaView;
INDArray dBetaView; INDArray dBetaView;
INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN); INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
@ -129,6 +130,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c'); dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
} else { } else {
gamma = getParam(BatchNormalizationParamInitializer.GAMMA); gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
beta = getParam(BatchNormalizationParamInitializer.BETA);
dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA); dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA); dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
} }
@ -152,15 +154,14 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
eps = epsilon; eps = epsilon;
} }
// FIXME: int cast
Pair<Gradient,INDArray> ret = null; Pair<Gradient,INDArray> ret = null;
try { try {
ret = helper.backpropGradient(in, eps, ArrayUtil.toInts(shape), gamma, dGammaView, dBetaView, ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView,
layerConf.getEps(), workspaceMgr); layerConf.getEps(), workspaceMgr);
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Throwable t){ } catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){ if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation //This is a memory exception - don't fallback to built-in implementation
throw t; throw t;
} }
@ -438,7 +439,6 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
//Note that cudnn does not support dense (2d) batch norm case as of v7.1 //Note that cudnn does not support dense (2d) batch norm case as of v7.1
double decay = layerConf.getDecay(); double decay = layerConf.getDecay();
// FIXME: int cast
INDArray ret = null; INDArray ret = null;
try { try {
if(globalVarView == null){ if(globalVarView == null){
@ -448,12 +448,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
globalVarView.muli(globalVarView); globalVarView.muli(globalVarView);
} }
ret = helper.preOutput(in, training == TrainingMode.TRAIN, ArrayUtil.toInts(shape), gamma, beta, globalMeanView, ret = helper.preOutput(in, training == TrainingMode.TRAIN, shape, gamma, beta, globalMeanView,
globalVarView, decay, layerConf.getEps(), workspaceMgr); globalVarView, decay, layerConf.getEps(), workspaceMgr);
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Throwable t) { } catch (Throwable t) {
if(t.getMessage().contains("Failed to allocate")){ if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation //This is a memory exception - don't fallback to built-in implementation
throw t; throw t;
} }

View File

@ -31,10 +31,10 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
public interface BatchNormalizationHelper extends LayerHelper { public interface BatchNormalizationHelper extends LayerHelper {
boolean checkSupported(double eps, boolean fixedGammaBeta); boolean checkSupported(double eps, boolean fixedGammaBeta);
Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);
INDArray getMeanCache(DataType dataType); INDArray getMeanCache(DataType dataType);

View File

@ -144,7 +144,7 @@ public class LocalResponseNormalization
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Throwable t){ } catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){ if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation //This is a memory exception - don't fallback to built-in implementation
throw t; throw t;
} }
@ -211,7 +211,7 @@ public class LocalResponseNormalization
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Throwable t){ } catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){ if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation //This is a memory exception - don't fallback to built-in implementation
throw t; throw t;
} }

View File

@ -114,10 +114,9 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
double lambdaCoord = layerConf().getLambdaCoord(); double lambdaCoord = layerConf().getLambdaCoord();
double lambdaNoObj = layerConf().getLambdaNoObj(); double lambdaNoObj = layerConf().getLambdaNoObj();
// FIXME: int cast long mb = input.size(0);
int mb = (int) input.size(0); long h = input.size(2);
int h = (int) input.size(2); long w = input.size(3);
int w = (int) input.size(3);
int b = (int) layerConf().getBoundingBoxes().size(0); int b = (int) layerConf().getBoundingBoxes().size(0);
int c = (int) labels.size(1)-4; int c = (int) labels.size(1)-4;
@ -243,12 +242,12 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
//Class prediction loss //Class prediction loss
INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c] INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c]
.dup('c').reshape('c', new int[]{mb*b*h*w, c}); .dup('c').reshape('c', new long[]{mb*b*h*w, c});
INDArray classLabelsBroadcast = Nd4j.createUninitialized(input.dataType(), new long[]{mb, b, c, h, w}, 'c'); INDArray classLabelsBroadcast = Nd4j.createUninitialized(input.dataType(), new long[]{mb, b, c, h, w}, 'c');
for(int i=0; i<b; i++ ){ for(int i=0; i<b; i++ ){
classLabelsBroadcast.get(all(), point(i), all(), all(), all()).assign(classLabels); //[mb, c, h, w] to [mb, b, c, h, w] classLabelsBroadcast.get(all(), point(i), all(), all(), all()).assign(classLabels); //[mb, c, h, w] to [mb, b, c, h, w]
} }
INDArray classLabels2d = classLabelsBroadcast.permute(0,1,3,4,2).dup('c').reshape('c', new int[]{mb*b*h*w, c}); INDArray classLabels2d = classLabelsBroadcast.permute(0,1,3,4,2).dup('c').reshape('c', new long[]{mb*b*h*w, c});
//Calculate the loss: //Calculate the loss:
ILossFunction lossConfidence = new LossL2(); ILossFunction lossConfidence = new LossL2();
@ -297,7 +296,7 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
// ----- Gradient Calculation (specifically: return dL/dIn ----- // ----- Gradient Calculation (specifically: return dL/dIn -----
INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c'); INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new int[]{mb, b, 5+c, h, w}, false); INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new long[]{mb, b, 5+c, h, w}, false);
INDArray epsClassPredictions = epsOut5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [mb, b, 5+c, h, w] INDArray epsClassPredictions = epsOut5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [mb, b, 5+c, h, w]
INDArray epsXY = epsOut5.get(all(), all(), interval(0,2), all(), all()); INDArray epsXY = epsOut5.get(all(), all(), interval(0,2), all(), all());
INDArray epsWH = epsOut5.get(all(), all(), interval(2,4), all(), all()); INDArray epsWH = epsOut5.get(all(), all(), interval(2,4), all(), all());
@ -426,16 +425,16 @@ public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
* @return IOU and gradients * @return IOU and gradients
*/ */
private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask, INDArray objectPresentMaskBool){ private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask, INDArray objectPresentMaskBool){
// FIXME: int cast
int mb = (int) labelTL.size(0); long mb = labelTL.size(0);
int h = (int) labelTL.size(2); long h = labelTL.size(2);
int w = (int) labelTL.size(3); long w = labelTL.size(3);
int b = (int) predictedWH.size(1); long b = predictedWH.size(1);
INDArray labelWH = labelBR.sub(labelTL); //4d [mb, 2, H, W], label W/H in terms of number of grid boxes INDArray labelWH = labelBR.sub(labelTL); //4d [mb, 2, H, W], label W/H in terms of number of grid boxes
int gridH = (int) labelTL.size(2); long gridH = labelTL.size(2);
int gridW = (int) labelTL.size(3); long gridW = labelTL.size(3);
//Add grid positions to the predicted XY values (to get predicted XY in terms of grid cell units in image, //Add grid positions to the predicted XY values (to get predicted XY in terms of grid cell units in image,
// from (0 to 1 in grid cell) format) // from (0 to 1 in grid cell) format)
INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, predictedWH.dataType()); INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, predictedWH.dataType());

View File

@ -45,12 +45,11 @@ public class YoloUtils {
} }
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){ public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr){
// FIXME: int cast long mb = input.size(0);
int mb = (int) input.size(0); long h = input.size(2);
int h = (int) input.size(2); long w = input.size(3);
int w = (int) input.size(3); long b = boundingBoxPriors.size(0);
int b = (int) boundingBoxPriors.size(0); long c = input.size(1)/b-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
int c = (int) (input.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c'); INDArray output = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), input.shape(), 'c');
INDArray output5 = output.reshape('c', mb, b, 5+c, h, w); INDArray output5 = output.reshape('c', mb, b, 5+c, h, w);
@ -77,7 +76,7 @@ public class YoloUtils {
//TODO OPTIMIZE? //TODO OPTIMIZE?
INDArray inputClassesPreSoftmax = input5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W] INDArray inputClassesPreSoftmax = input5.get(all(), all(), interval(5, 5+c), all(), all()); //Shape: [minibatch, C, H, W]
INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c] INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(0,1,3,4,2) //[minibatch, b, c, h, w] To [mb, b, h, w, c]
.dup('c').reshape('c', new int[]{mb*b*h*w, c}); .dup('c').reshape('c', new long[]{mb*b*h*w, c});
Transforms.softmax(classPredictionsPreSoftmax2d, false); Transforms.softmax(classPredictionsPreSoftmax2d, false);
INDArray postSoftmax5d = classPredictionsPreSoftmax2d.reshape('c', mb, b, h, w, c ).permute(0, 1, 4, 2, 3); INDArray postSoftmax5d = classPredictionsPreSoftmax2d.reshape('c', mb, b, h, w, c ).permute(0, 1, 4, 2, 3);
@ -173,13 +172,12 @@ public class YoloUtils {
throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold); throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold);
} }
// FIXME: int cast
//Activations format: [mb, 5b+c, h, w] //Activations format: [mb, 5b+c, h, w]
int mb = (int) networkOutput.size(0); long mb = networkOutput.size(0);
int h = (int) networkOutput.size(2); long h = networkOutput.size(2);
int w = (int) networkOutput.size(3); long w = networkOutput.size(3);
int b = (int) boundingBoxPriors.size(0); long b = boundingBoxPriors.size(0);
int c = (int) (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5 long c = (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
//Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W] to [minibatch, B, 5, H, W] //Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W] to [minibatch, B, 5, H, W]
INDArray output5 = networkOutput.dup('c').reshape(mb, b, 5+c, h, w); INDArray output5 = networkOutput.dup('c').reshape(mb, b, 5+c, h, w);

View File

@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
} }
} }
} }
/*
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
else if ("CPU".equalsIgnoreCase(backend) && BaseMKLDNNHelper.mklDnnEnabled()){
helper = new MKLDNNLSTMHelper();
log.debug("MKLDNNLSTMHelper successfully initialized");
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
helper = null;
}
}
*/
} }
@Override @Override

View File

@ -40,6 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus; import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -113,7 +114,9 @@ public class LSTMHelpers {
input = input.castTo(inputWeights.dataType()); //No-op if already correct dtype input = input.castTo(inputWeights.dataType()); //No-op if already correct dtype
// FIXME if ((!is2dInput && (input.size(2) > Integer.MAX_VALUE)) ||
recurrentWeights.size(0) > Integer.MAX_VALUE || input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int timeSeriesLength = (int) (is2dInput ? 1 : input.size(2)); int timeSeriesLength = (int) (is2dInput ? 1 : input.size(2));
int hiddenLayerSize = (int) recurrentWeights.size(0); int hiddenLayerSize = (int) recurrentWeights.size(0);
int miniBatchSize = (int) input.size(0); int miniBatchSize = (int) input.size(0);
@ -550,7 +553,8 @@ public class LSTMHelpers {
for (long iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) { for (long iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.RNN_BP_LOOP_WORKING_MEM)) { try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.RNN_BP_LOOP_WORKING_MEM)) {
// FIXME: int cast if (iTimeIndex > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int time = (int) iTimeIndex; int time = (int) iTimeIndex;
int inext = 1; int inext = 1;
@ -574,8 +578,6 @@ public class LSTMHelpers {
(iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]); (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[(int) (time - inext)]);
INDArray currMemCellState = fwdPass.memCellState[(int) time]; INDArray currMemCellState = fwdPass.memCellState[(int) time];
// FIXME: int cast
//LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out) //LSTM unit output errors (dL/d(a_out)); not to be confused with \delta=dL/d(z_out)
INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv. INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension((int) time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.

View File

@ -89,8 +89,7 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
ILossFunction lossFunction = layerConf().getLossFn(); ILossFunction lossFunction = layerConf().getLossFn();
INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped); INDArray delta2d = lossFunction.computeGradient(labels2d, input2d.dup(input2d.ordering()), layerConf().getActivationFn(), maskReshaped);
// FIXME: int cast INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
INDArray delta3d = TimeSeriesUtils.reshape2dTo3d(delta2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
// grab the empty gradient // grab the empty gradient
Gradient gradient = new DefaultGradient(); Gradient gradient = new DefaultGradient();
@ -119,7 +118,6 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
@Override @Override
public int numLabels() { public int numLabels() {
// FIXME: int cast
return (int) labels.size(1); return (int) labels.size(1);
} }
@ -167,7 +165,7 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
INDArray as2d = TimeSeriesUtils.reshape3dTo2d(input); INDArray as2d = TimeSeriesUtils.reshape3dTo2d(input);
INDArray out2d = layerConf().getActivationFn().getActivation(workspaceMgr.dup(ArrayType.ACTIVATIONS, as2d, as2d.ordering()), training); INDArray out2d = layerConf().getActivationFn().getActivation(workspaceMgr.dup(ArrayType.ACTIVATIONS, as2d, as2d.ordering()), training);
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, (int)input.size(0), workspaceMgr, ArrayType.ACTIVATIONS)); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS));
} }
@Override @Override
@ -254,7 +252,6 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
//scoreArray: shape [minibatch*timeSeriesLength, 1] //scoreArray: shape [minibatch*timeSeriesLength, 1]
//Reshape it to [minibatch, timeSeriesLength] then sum over time step //Reshape it to [minibatch, timeSeriesLength] then sum over time step
// FIXME: int cast
INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0)); INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0));
INDArray summedScores = scoreArrayTs.sum(1); INDArray summedScores = scoreArrayTs.sum(1);

View File

@ -70,8 +70,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
this.input = inputTemp; this.input = inputTemp;
INDArray epsilon2d = gradAndEpsilonNext.getSecond(); INDArray epsilon2d = gradAndEpsilonNext.getSecond();
// FIXME: int cast INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
INDArray epsilon3d = TimeSeriesUtils.reshape2dTo3d(epsilon2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATION_GRAD);
weightNoiseParams.clear(); weightNoiseParams.clear();
@ -145,8 +144,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
} }
} }
// FIXME: int cast return TimeSeriesUtils.reshape2dTo3d(act2d, input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
return TimeSeriesUtils.reshape2dTo3d(act2d, (int) input.size(0), workspaceMgr, ArrayType.ACTIVATIONS);
} }
@Override @Override
@ -205,7 +203,6 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
//scoreArray: shape [minibatch*timeSeriesLength, 1] //scoreArray: shape [minibatch*timeSeriesLength, 1]
//Reshape it to [minibatch, timeSeriesLength] then sum over time step //Reshape it to [minibatch, timeSeriesLength] then sum over time step
// FIXME: int cast
INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0)); INDArray scoreArrayTs = TimeSeriesUtils.reshapeVectorToTimeSeriesMask(scoreArray, (int)input.size(0));
INDArray summedScores = scoreArrayTs.sum(true, 1); INDArray summedScores = scoreArrayTs.sum(true, 1);

View File

@ -0,0 +1,68 @@
package org.deeplearning4j.nn.layers.samediff;
import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
/**
* A SameDiff {@link org.nd4j.autodiff.samediff.internal.SessionMemMgr} that uses DL4J workspaces for memory management.
* Any op outputs are allocated in the output workspace if they are returned to the layer; otherwise they are placed in
* the DL4J working memory workspace
*
* @author Alex Black
*/
public class DL4JSameDiffMemoryMgr extends AbstractMemoryMgr {
private final String workingMemoryWs;
private final String outputWs;
private final WorkspaceConfiguration confWorking;
private final WorkspaceConfiguration confOutput;
//Note: if the working memory or output workspace names are null -> detached memory
public DL4JSameDiffMemoryMgr(String workingMemoryWs, String outputWs, WorkspaceConfiguration confWorking,
WorkspaceConfiguration confOutput){
this.workingMemoryWs = workingMemoryWs;
this.outputWs = outputWs;
this.confWorking = confWorking;
this.confOutput = confOutput;
}
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
String wsName = detached ? outputWs : workingMemoryWs;
WorkspaceConfiguration wsConf = detached ? confOutput : confWorking;
if(wsName == null){
//Scoped out
INDArray ret = Nd4j.createUninitializedDetached(dataType, shape);
Preconditions.checkState(!ret.isAttached(), "Returned array should be detached");
return ret;
} else {
MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, wsName);
try (MemoryWorkspace mw = ws.notifyScopeBorrowed()) {
return Nd4j.createUninitialized(dataType, shape);
}
}
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
return allocate(detached, descriptor.dataType(), descriptor.getShape());
}
@Override
public void release(INDArray array) {
//No-op - DL4J workspaces handles this
}
@Override
public void close() {
//No-op - DL4J workspaces handles this
}
}

View File

@ -31,9 +31,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -98,6 +101,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
if (sameDiff == null) { if (sameDiff == null) {
doInit(); doInit();
} }
}
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
config.validateInput(inputs); config.validateInput(inputs);
@ -112,6 +116,25 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
} }
} }
//Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
if(is == null){
is = new InferenceSession(sameDiff);
sameDiff.getSessions().put(Thread.currentThread().getId(), is);
}
is.setMmgr(mmgr);
if(paramTable != null && paramTable.size() > 0) { if(paramTable != null && paramTable.size() > 0) {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this //TODO Find a more efficient solution for this
@ -122,22 +145,29 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
} }
INDArray result = sameDiff.outputSingle(phMap, outputKey); INDArray result = sameDiff.outputSingle(phMap, outputKey);
//Edge case: "vertex" is just an identity activation, for example
//TODO there may be a cleaner way to do this...
if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
} else if(actScopedOut && result.isAttached()){
result = result.detach();
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs(); sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
} }
}
@Override @Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
Gradient g = new DefaultGradient(); Gradient g = new DefaultGradient();
INDArray[] dLdIns;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if (sameDiff == null) { if (sameDiff == null) {
doInit(); doInit();
} }
}
List<String> inputNames = config.getVertexParams().getInputs(); List<String> inputNames = config.getVertexParams().getInputs();
if(!sameDiff.hasGradientFunction()) { if(!sameDiff.hasGradientFunction()) {
@ -146,6 +176,24 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
sameDiff.createGradFunction(inArr); sameDiff.createGradFunction(inArr);
} }
config.validateInput(inputs); config.validateInput(inputs);
//Configure memory management for SameDiff instance - use DL4J workspaces
Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
if(!sessionMap.containsKey(Thread.currentThread().getId())){
sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
}
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
List<String> inputs = config.getVertexParams().getInputs(); List<String> inputs = config.getVertexParams().getInputs();
int i=0; int i=0;
@ -167,41 +215,43 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this //TODO Find a more efficient solution for this
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) { for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue(); INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated required.addAll(paramTable.keySet());
for(String s : inputNames){ required.addAll(inputNames);
required.add(sameDiff.getVariable(s).gradient().getVarName());
} Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
sameDiff.execBackwards(phMap, required);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = gradsMap.get(s);
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
g.gradientForVariable().put(s, dl4jGrad); g.gradientForVariable().put(s, dl4jGrad);
} }
dLdIns = new INDArray[inputs.size()]; INDArray[] dLdIns = new INDArray[inputs.size()];
String fnName = fn.getGradPlaceholderName(); String fnName = fn.getGradPlaceholderName();
for(int j=0; j<inputs.size(); j++ ){ for(int j=0; j<inputs.size(); j++ ){
String name = inputs.get(j); String name = inputs.get(j);
dLdIns[j] = sameDiff.grad(name).getArr(); dLdIns[j] = sameDiff.grad(name).getArr();
String gradName = sameDiff.grad(inputNames.get(j)).getVarName(); String gradName = sameDiff.grad(inputNames.get(j)).name();
if(dLdIns[j] == null && fnName.equals(gradName)){ if(dLdIns[j] == null && fnName.equals(gradName)){
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIns[j] = epsilon; dLdIns[j] = epsilon;
} }
}
}
//TODO optimize //Edge case: "vertex" is just an identity activation, for example
for( int i=0; i<dLdIns.length; i++ ){ //TODO there may be a cleaner way to do this...
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]); if(!actGradScopedOut && !dLdIns[j].data().getParentWorkspace().getId().equals(wsNameActGrad)){
dLdIns[j] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[j]);
} else if(actGradScopedOut && dLdIns[j].isAttached()){
dLdIns[j] = dLdIns[j].detach();
}
} }
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@ -264,7 +314,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
fn = sameDiff.f().externalErrors(layerOutput); fn = sameDiff.f().externalErrors(layerOutput);
fn.outputVariable(); fn.outputVariable();
this.outputKey = outputVar.getVarName(); this.outputKey = outputVar.name();
} }
} }

View File

@ -26,9 +26,12 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -84,6 +87,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
if (sameDiff == null) { if (sameDiff == null) {
doInit(); doInit();
} }
}
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
@ -103,15 +107,39 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
//Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
if(is == null){
is = new InferenceSession(sameDiff);
sameDiff.getSessions().put(Thread.currentThread().getId(), is);
}
is.setMmgr(mmgr);
Map<String,INDArray> out = sameDiff.output(phMap, outputKey); Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
INDArray result = out.get(outputKey); INDArray result = out.get(outputKey);
//Edge case - identity activation
//TODO there may be a cleaner way to do this...
if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
} else if(actScopedOut && result.isAttached()){
result = result.detach();
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs(); sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); return result;
}
} }
@ -122,6 +150,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
Gradient g = new DefaultGradient(); Gradient g = new DefaultGradient();
INDArray dLdIn; INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if (sameDiff == null) { if (sameDiff == null) {
doInit(); doInit();
@ -130,6 +159,22 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
//Create when scoped out, to ensure any arrays are not in WS //Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY); sameDiff.createGradFunction(INPUT_KEY);
} }
}
//Configure memory management for SameDiff instance - use DL4J workspaces
Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
if(!sessionMap.containsKey(Thread.currentThread().getId())){
sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
}
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
@ -151,34 +196,26 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
} }
List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1); List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName()); requiredGrads.add(INPUT_KEY);
for(String s : paramTable.keySet()){ requiredGrads.addAll(paramTable.keySet());
requiredGrads.add(sameDiff.grad(s).getVarName());
}
sameDiff.execBackwards(phMap, requiredGrads); Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = m.get(s);
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
g.gradientForVariable().put(s, dl4jGrad); g.gradientForVariable().put(s, dl4jGrad);
} }
SDVariable v = sameDiff.grad(INPUT_KEY); dLdIn = m.get(INPUT_KEY);
dLdIn = v.getArr();
if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
//Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIn = epsilon;
}
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs(); sameDiff.clearOpInputs();
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
return ret;
} }
/**Returns the parameters of the neural network as a flattened row vector /**Returns the parameters of the neural network as a flattened row vector
@ -291,7 +328,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
fn = sameDiff.f().externalErrors(layerOutput); fn = sameDiff.f().externalErrors(layerOutput);
fn.outputVariable(); fn.outputVariable();
this.outputKey = outputVar.getVarName(); this.outputKey = outputVar.name();
} }
} }

View File

@ -29,9 +29,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
@ -98,6 +101,25 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
if (sameDiff == null) { if (sameDiff == null) {
doInit(); doInit();
} }
}
//Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
if(is == null){
is = new InferenceSession(sameDiff);
sameDiff.getSessions().put(Thread.currentThread().getId(), is);
}
is.setMmgr(mmgr);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this //TODO Find a more efficient solution for this
@ -112,7 +134,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
phMap.put(LABELS_KEY, labels); phMap.put(LABELS_KEY, labels);
} }
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName(); String s = activations ? layerConf().activationsVertexName() : outputVar.name();
INDArray out = sameDiff.outputSingle(phMap, s); INDArray out = sameDiff.outputSingle(phMap, s);
@ -120,16 +142,16 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs(); sameDiff.clearOpInputs();
if(activations) { //Edge case: vertex is just an Identity function, for example
Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " + //TODO there may be a cleaner way to do this...
"null - error during execution or this variable (as defined by method activationsVertexName()) " + if(!actScopedOut && !out.data().getParentWorkspace().getId().equals(wsNameOutput)){
"does not exist", layerConf().activationsVertexName()); out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out); } else if(actScopedOut && out.isAttached()){
} else { out = out.detach();
}
return out; return out;
} }
}
}
@Override @Override
@ -147,6 +169,25 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph) // (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit(); doInit();
} }
if(sameDiff.getFunction("grad") == null)
sameDiff.createGradFunction(INPUT_KEY);
}
//Configure memory management for SameDiff instance - use DL4J workspaces
Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
if(!sessionMap.containsKey(Thread.currentThread().getId())){
sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
}
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
if(!sameDiff.hasGradientFunction()) { if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS //Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY); sameDiff.createGradFunction(INPUT_KEY);
@ -160,31 +201,38 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
} }
List<String> gradVarNames = new ArrayList<>(); List<String> gradVarNames = new ArrayList<>();
for(String s : paramTable.keySet()){ gradVarNames.addAll(paramTable.keySet());
gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName()); gradVarNames.add(INPUT_KEY);
}
gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input); phMap.put(INPUT_KEY, input);
phMap.put(LABELS_KEY, labels); phMap.put(LABELS_KEY, labels);
sameDiff.execBackwards(phMap, gradVarNames); Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = grads.get(s);
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS dl4jGrad.assign(sdGrad); //TODO OPTIMIZE THIS
g.gradientForVariable().put(s, dl4jGrad); g.gradientForVariable().put(s, dl4jGrad);
if(sdGrad.closeable()){
sdGrad.close();
}
} }
dLdIn = sameDiff.grad(INPUT_KEY).getArr(); dLdIn = grads.get(INPUT_KEY);
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs(); sameDiff.clearOpInputs();
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS //TODO there may be a cleaner way to do this...
if(!actGradScopedOut && !dLdIn.data().getParentWorkspace().getId().equals(wsNameActGrad)){
dLdIn = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn);
} else if(actGradScopedOut && dLdIn.isAttached()){
dLdIn = dLdIn.detach();
}
return new Pair<>(g, dLdIn);
} }
/**Returns the parameters of the neural network as a flattened row vector /**Returns the parameters of the neural network as a flattened row vector
@ -297,7 +345,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey())); sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
} }
this.outputKey = layerOutput.getVarName(); this.outputKey = layerOutput.name();
} }
} }
@ -308,7 +356,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
@Override @Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) { public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
return (activateHelper(false, workspaceMgr).getDouble(0) + fullNetRegTerm) / input.size(0); INDArray scoreArr = activateHelper(false, workspaceMgr);
return (scoreArr.getDouble(0) + fullNetRegTerm) / input.size(0);
} }
@Override @Override

View File

@ -41,6 +41,7 @@ import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction;
@ -552,7 +553,8 @@ public class VariationalAutoencoder implements Layer {
@Override @Override
public int batchSize() { public int batchSize() {
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return (int) input.size(0); return (int) input.size(0);
} }
@ -862,7 +864,8 @@ public class VariationalAutoencoder implements Layer {
@Override @Override
public int getInputMiniBatchSize() { public int getInputMiniBatchSize() {
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return (int) input.size(0); return (int) input.size(0);
} }

View File

@ -75,6 +75,7 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat; import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment; import org.nd4j.linalg.heartbeat.reports.Environment;
@ -425,7 +426,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) { if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) {
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx).preProcess(outputOfPrevLayer, (int) input.size(0), outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx).preProcess(outputOfPrevLayer, (int) input.size(0),
LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); LayerWorkspaceMgr.noWorkspaces(helperWorkspaces));
} }
@ -439,7 +441,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
//In 99+% of cases, the input and labels dimension 0 size should be identical //In 99+% of cases, the input and labels dimension 0 size should be identical
//The only real exceptions: space to batch, and batch to space layers //The only real exceptions: space to batch, and batch to space layers
//In those cases, we should base it on the labels size, as this impacts gradient calculation //In those cases, we should base it on the labels size, as this impacts gradient calculation
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return labels == null ? (int) input.size(0) : (int)labels.size(0); return labels == null ? (int) input.size(0) : (int)labels.size(0);
} }
@ -2074,7 +2077,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (endTimeIdx > timeSeriesLength) if (endTimeIdx > timeSeriesLength)
endTimeIdx = timeSeriesLength; endTimeIdx = timeSeriesLength;
// FIXME: int cast if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels,
featuresMaskArray, labelsMaskArray); featuresMaskArray, labelsMaskArray);
@ -2211,7 +2215,9 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
public int[] predict(INDArray d) { public int[] predict(INDArray d) {
INDArray output = output(d, Layer.TrainingMode.TEST); INDArray output = output(d, Layer.TrainingMode.TEST);
// FIXME: int cast if (d.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int[] ret = new int[(int) d.size(0)]; int[] ret = new int[(int) d.size(0)];
if (d.isRowVectorOrScalar()) if (d.isRowVectorOrScalar())
ret[0] = Nd4j.getBlasWrapper().iamax(output); ret[0] = Nd4j.getBlasWrapper().iamax(output);
@ -2335,7 +2341,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = org.deeplearning4j.nn.conf.layers.OutputLayer layerConf =
(org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer(); (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer();
// FIXME: int cast if (layerConf.getNOut() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut())); fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut()));
} }
@ -2584,7 +2591,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD,layers.length-2, data.getFeatures(), INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD,layers.length-2, data.getFeatures(),
data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null);
// FIXME: int cast if (data.getFeatures().size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
IOutputLayer ol = (IOutputLayer) getOutputLayer(); IOutputLayer ol = (IOutputLayer) getOutputLayer();
if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) { if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) {
inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1)
@ -2647,7 +2655,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
IOutputLayer ol = (IOutputLayer) getOutputLayer(); IOutputLayer ol = (IOutputLayer) getOutputLayer();
if(layerWiseConfigurations.getInputPreProcess(layers.length-1) != null){ if(layerWiseConfigurations.getInputPreProcess(layers.length-1) != null){
// FIXME: int cast if (data.getFeatures().size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
inputLast = layerWiseConfigurations.getInputPreProcess(layers.length-1).preProcess(inputLast, inputLast = layerWiseConfigurations.getInputPreProcess(layers.length-1).preProcess(inputLast,
(int) data.getFeatures().size(0), mgr); (int) data.getFeatures().size(0), mgr);
} }
@ -2811,7 +2820,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")"); "Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")");
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
setInputMiniBatchSize((int) input.size(0)); setInputMiniBatchSize((int) input.size(0));
} }
} }
@ -3086,7 +3096,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if(!conf().isMiniBatch()) if(!conf().isMiniBatch())
return 1; return 1;
// FIXME: int cast if (input.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return (int) input.size(0); return (int) input.size(0);
} }
@ -3256,7 +3267,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) {
if (featuresMaskArray != null) { if (featuresMaskArray != null) {
// FIXME: int cast if (featuresMaskArray.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//New approach: use feedForwardMaskArray method //New approach: use feedForwardMaskArray method
feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0));
@ -3438,7 +3450,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
val startTimeIdx = i * fwdLen; val startTimeIdx = i * fwdLen;
val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength); val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);
// FIXME: int cast if (endTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask); INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask);
setLayerMaskArrays(subsets[2], subsets[3]); setLayerMaskArrays(subsets[2], subsets[3]);
@ -3943,7 +3956,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
FeedForwardLayer ffl = (FeedForwardLayer) conf; FeedForwardLayer ffl = (FeedForwardLayer) conf;
// FIXME: int cast if (ffl.getNOut() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return (int) ffl.getNOut(); return (int) ffl.getNOut();
} }
@ -3969,7 +3983,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
FeedForwardLayer ffl = (FeedForwardLayer) conf; FeedForwardLayer ffl = (FeedForwardLayer) conf;
// FIXME: int cast if (ffl.getNIn() > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
return (int) ffl.getNIn(); return (int) ffl.getNIn();
} }

View File

@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList; import java.util.ArrayList;
@ -108,7 +109,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
} }
//Between last decoder layer and parameters for p(x|z): //Between last decoder layer and parameters for p(x|z):
// FIXME: int cast if (nIn > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1]; val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
paramCount += (lastDecLayerSize + 1) * nDistributionParams; paramCount += (lastDecLayerSize + 1) * nDistributionParams;
@ -294,7 +296,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
} }
//Finally, p(x|z): //Finally, p(x|z):
// FIXME: int cast if (nIn > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
INDArray pxzWeightView = INDArray pxzWeightView =
@ -402,7 +405,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali
} }
//Finally, p(x|z): //Finally, p(x|z):
// FIXME: int cast if (nIn > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn); int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams; int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
INDArray pxzWeightView = INDArray pxzWeightView =

View File

@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
@ -111,7 +112,8 @@ public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater
if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable, if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable,
layers[i], var)) { layers[i], var)) {
// FIXME: int cast if (paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE || paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//Create a new block //Create a new block
List<UpdaterBlock.ParamState> list = new ArrayList<>(); List<UpdaterBlock.ParamState> list = new ArrayList<>();
list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar, list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
@ -122,9 +124,11 @@ public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater
updaterBlocks.add(currentBlock); updaterBlocks.add(currentBlock);
} else { } else {
// FIXME: int cast long newOffset = currentBlock.getParamOffsetEnd() + paramSizeThisVariable;
if (newOffset > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//Add to existing updater block //Add to existing updater block
currentBlock.setParamOffsetEnd((int) (currentBlock.getParamOffsetEnd() + paramSizeThisVariable)); currentBlock.setParamOffsetEnd((int) newOffset);
currentBlock.setUpdaterViewOffsetEnd( currentBlock.setUpdaterViewOffsetEnd(
currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable); currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
currentBlock.getLayersAndVariablesInBlock() currentBlock.getLayersAndVariablesInBlock()

View File

@ -25,6 +25,7 @@ import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
/** /**
@ -37,7 +38,83 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
private int frequency; private int frequency;
private int iterationCount = 0; private int iterationCount = 0;
private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>(); //private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
public static class ScoreStat {
public static final int BUCKET_LENGTH = 10000;
private int position = 0;
private int bucketNumber = 1;
private List<long[]> indexes;
private List<double[]> scores;
public ScoreStat() {
indexes = new ArrayList<>(1);
indexes.add(new long[BUCKET_LENGTH]);
scores = new ArrayList<>(1);
scores.add(new double[BUCKET_LENGTH]);
}
public List<long[]> getIndexes() {
return indexes;
}
public List<double[]> getScores() {
return scores;
}
public long[] getEffectiveIndexes() {
return Arrays.copyOfRange(indexes.get(0), 0, position);
}
public double[] getEffectiveScores() {
return Arrays.copyOfRange(scores.get(0), 0, position);
}
/*
Originally scores array is initialized with BUCKET_LENGTH size.
When data doesn't fit there - arrays size is increased for BUCKET_LENGTH,
old data is copied and bucketNumber (counter of reallocations) being incremented.
If we got more score points than MAX_VALUE - they are put to another item of scores list.
*/
private void reallocateGuard() {
if (position >= BUCKET_LENGTH * bucketNumber) {
long fullLength = (long)BUCKET_LENGTH * bucketNumber;
if (position == Integer.MAX_VALUE || fullLength >= Integer.MAX_VALUE) {
position = 0;
long[] newIndexes = new long[BUCKET_LENGTH];
double[] newScores = new double[BUCKET_LENGTH];
indexes.add(newIndexes);
scores.add(newScores);
}
else {
long[] newIndexes = new long[(int)fullLength + BUCKET_LENGTH];
double[] newScores = new double[(int)fullLength + BUCKET_LENGTH];
System.arraycopy(indexes.get(indexes.size()-1), 0, newIndexes, 0, (int)fullLength);
System.arraycopy(scores.get(scores.size()-1), 0, newScores, 0, (int)fullLength);
scores.remove(scores.size()-1);
indexes.remove(indexes.size()-1);
int lastIndex = scores.size() == 0 ? 0 : scores.size()-1;
scores.add(lastIndex, newScores);
indexes.add(lastIndex, newIndexes);
}
bucketNumber += 1;
}
}
public void addScore(long index, double score) {
reallocateGuard();
scores.get(scores.size() - 1)[position] = score;
indexes.get(scores.size() - 1)[position] = index;
position += 1;
}
}
ScoreStat scoreVsIter = new ScoreStat();
/** /**
* Constructor for collecting scores with default saving frequency of 1 * Constructor for collecting scores with default saving frequency of 1
@ -60,11 +137,12 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
public void iterationDone(Model model, int iteration, int epoch) { public void iterationDone(Model model, int iteration, int epoch) {
if (++iterationCount % frequency == 0) { if (++iterationCount % frequency == 0) {
double score = model.score(); double score = model.score();
scoreVsIter.add(new Pair<>(iterationCount, score)); scoreVsIter.reallocateGuard();
scoreVsIter.addScore(iteration, score);
} }
} }
public List<Pair<Integer, Double>> getScoreVsIter() { public ScoreStat getScoreVsIter() {
return scoreVsIter; return scoreVsIter;
} }
@ -84,8 +162,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
public void exportScores(OutputStream outputStream, String delimiter) throws IOException { public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append("Iteration").append(delimiter).append("Score"); sb.append("Iteration").append(delimiter).append("Score");
for (Pair<Integer, Double> p : scoreVsIter) { int largeBuckets = scoreVsIter.indexes.size();
sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond()); for (int j = 0; j < largeBuckets; ++j) {
long[] indexes = scoreVsIter.indexes.get(j);
double[] scores = scoreVsIter.scores.get(j);
int effectiveLength = (j < largeBuckets -1) ? indexes.length : scoreVsIter.position;
for (int i = 0; i < effectiveLength; ++i) {
sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]);
}
} }
outputStream.write(sb.toString().getBytes("UTF-8")); outputStream.write(sb.toString().getBytes("UTF-8"));
} }

View File

@ -29,6 +29,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
@ -62,10 +63,9 @@ public class Convolution1DUtils {
* @param dilation Kernel dilation * @param dilation Kernel dilation
* @return Output size (width) * @return Output size (width)
*/ */
public static int getOutputSize(int inH, int kernel, int strides, int padding, public static long getOutputSize(long inH, int kernel, int strides, int padding,
ConvolutionMode convolutionMode, int dilation) { ConvolutionMode convolutionMode, int dilation) {
// FIXME: int cast long eKernel = effectiveKernelSize(kernel, dilation);
int eKernel = effectiveKernelSize(kernel, dilation);
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
return (int) Math.ceil(inH / ((double) strides)); return (int) Math.ceil(inH / ((double) strides));
} }
@ -85,7 +85,8 @@ public class Convolution1DUtils {
*/ */
public static int getOutputSize(INDArray inputData, int kernel, int strides, int padding, public static int getOutputSize(INDArray inputData, int kernel, int strides, int padding,
ConvolutionMode convolutionMode, int dilation) { ConvolutionMode convolutionMode, int dilation) {
// FIXME: int cast if (inputData.size(2) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int inH = (int) inputData.size(2); int inH = (int) inputData.size(2);
int eKernel = effectiveKernelSize(kernel, dilation); int eKernel = effectiveKernelSize(kernel, dilation);
boolean atrous = (eKernel == kernel); boolean atrous = (eKernel == kernel);

View File

@ -61,15 +61,14 @@ public class Convolution3DUtils {
ConvolutionMode convolutionMode, int[] dilation, boolean isNCDHW) { ConvolutionMode convolutionMode, int[] dilation, boolean isNCDHW) {
// NCDHW vs. NDHWC // NCDHW vs. NDHWC
int inD = (int) (isNCDHW ? inputData.size(2) : inputData.size(1)); long inD = (isNCDHW ? inputData.size(2) : inputData.size(1));
int inH = (int) (isNCDHW ? inputData.size(3) : inputData.size(2)); long inH = (isNCDHW ? inputData.size(3) : inputData.size(2));
int inW = (int) (isNCDHW ? inputData.size(4) : inputData.size(3)); long inW = (isNCDHW ? inputData.size(4) : inputData.size(3));
int[] eKernel = effectiveKernelSize(kernel, dilation); int[] eKernel = effectiveKernelSize(kernel, dilation);
boolean atrous = (eKernel == kernel); boolean atrous = (eKernel == kernel);
// FIXME: int cast val inShape = new long[]{inD, inH, inW};
val inShape = new int[]{inD, inH, inW};
validateShapes(ArrayUtil.toInts(inputData.shape()), eKernel, strides, padding, convolutionMode, dilation, inShape, atrous); validateShapes(ArrayUtil.toInts(inputData.shape()), eKernel, strides, padding, convolutionMode, dilation, inShape, atrous);
if (convolutionMode == ConvolutionMode.Same) { if (convolutionMode == ConvolutionMode.Same) {
@ -80,16 +79,16 @@ public class Convolution3DUtils {
return new int[]{outD, outH, outW}; return new int[]{outD, outH, outW};
} }
int outD = (inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1; int outD = ((int)inD - eKernel[0] + 2 * padding[0]) / strides[0] + 1;
int outH = (inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1; int outH = ((int)inH - eKernel[1] + 2 * padding[1]) / strides[1] + 1;
int outW = (inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1; int outW = ((int)inW - eKernel[2] + 2 * padding[2]) / strides[2] + 1;
return new int[]{outD, outH, outW}; return new int[]{outD, outH, outW};
} }
private static void validateShapes(int[] inputDataShape, int[] eKernel, int[] strides, int[] padding, private static void validateShapes(int[] inputDataShape, int[] eKernel, int[] strides, int[] padding,
ConvolutionMode convolutionMode, int[] dilation, int[] inShape, ConvolutionMode convolutionMode, int[] dilation, long[] inShape,
boolean atrous) { boolean atrous) {
String[] dims = new String[]{"depth", "height", "width"}; String[] dims = new String[]{"depth", "height", "width"};

Some files were not shown because too many files have changed in this diff Show More