diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java new file mode 100644 index 000000000..54a47eead --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestBatchNormBp.java @@ -0,0 +1,107 @@ +package org.deeplearning4j; + +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.lang.reflect.Field; + +import static junit.framework.TestCase.*; + +public class TestBatchNormBp { + + @Test + public void test(){ + Nd4j.getRandom().setSeed(12345); +// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4); + INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); + INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3); + INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3); + INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); +// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3); +// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3); + INDArray gamma = Nd4j.rand(DataType.FLOAT, 3); + INDArray beta = Nd4j.rand(DataType.FLOAT, 3); + double e = 1e-5; + + INDArray dLdIn = in.ulike(); + INDArray dLdm = mean.ulike(); + INDArray dLdv = var.ulike(); + INDArray dLdg = gamma.ulike(); + INDArray dLdb = beta.ulike(); + + DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(in, mean, var, eps, gamma, beta) + .addIntegerArguments( + 1, //Apply scale + 1, //Apply beta + 1) //Axis (NCHW) + .addFloatingPointArguments(e) + .addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb) + .build(); + + Nd4j.exec(op); + System.out.println(dLdIn); + } + + @Test + public void compareImpls() throws Exception { + + Nd4j.getRandom().setSeed(12345); + INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); + INDArray mean = in.mean(0, 2, 3).reshape(1,3); + INDArray var = in.var(0, 2, 3).reshape(1,3); + INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); + INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3); + INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3); + double e = 1e-3; + + INDArray dLdIn = in.ulike(); + INDArray dLdm = mean.ulike(); + INDArray dLdv = var.ulike(); + INDArray dLdg = gamma.ulike(); + INDArray dLdb = beta.ulike(); + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .inferenceWorkspaceMode(WorkspaceMode.NONE) + .trainingWorkspaceMode(WorkspaceMode.NONE) + .list() + .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0); + assertNotNull(bn.getHelper()); + Field f = bn.getClass().getDeclaredField("helper"); + f.setAccessible(true); + f.set(bn, null); + assertNull(bn.getHelper()); + + + MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT); + + net.output(in, true); + bn.setInput(in, LayerWorkspaceMgr.noWorkspaces()); + Pair p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + h.preOutput(in, true, new int[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces()); + Pair pmkl = h.backpropGradient(in, eps, new int[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces()); + + INDArray dldin_dl4j = p.getSecond(); + + System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond())); + } + +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index a04553c37..7e3ae6720 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.primitives.Pair; +import java.lang.reflect.Field; import java.util.Arrays; import java.util.Collections; +import static junit.framework.TestCase.*; import static org.junit.Assume.assumeTrue; public class ValidateMKLDNN extends BaseDL4JTest { @@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { .padding(0, 0) .nOut(3) .build()) - .layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build()) + .layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build()) .layer(new ConvolutionLayer.Builder().activation(Activation.TANH) .kernelSize(kernel) .stride(stride) @@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest { } } } + + @Test + public void compareBatchNormBackward() throws Exception { + + Nd4j.getRandom().setSeed(12345); + INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15); + INDArray mean = in.mean(0, 2, 3).reshape(1,3); + INDArray var = in.var(0, 2, 3).reshape(1,3); + INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); + INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3); + INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3); + double e = 1e-3; + + INDArray dLdIn = in.ulike(); + INDArray dLdm = mean.ulike(); + INDArray dLdv = var.ulike(); + INDArray dLdg = gamma.ulike(); + INDArray dLdb = beta.ulike(); + + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .inferenceWorkspaceMode(WorkspaceMode.NONE) + .trainingWorkspaceMode(WorkspaceMode.NONE) + .list() + .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0); + assertNotNull(bn.getHelper()); + System.out.println(bn.getHelper()); + + net.output(in, true); + bn.setInput(in, LayerWorkspaceMgr.noWorkspaces()); + Pair pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + Field f = bn.getClass().getDeclaredField("helper"); + f.setAccessible(true); + f.set(bn, null); + assertNull(bn.getHelper()); + + net.output(in, true); + bn.setInput(in, LayerWorkspaceMgr.noWorkspaces()); + Pair p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + + INDArray dldin_dl4j = p.getSecond(); + INDArray dldin_helper = pcudnn.getSecond(); + + assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5)); + } } diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java index d296e78a4..aad549ad0 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java @@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba } @Override - public Pair backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, + public Pair backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { this.eps = eps; val miniBatch = (int) input.size(0); @@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba Pointer varCacheData = allocator.getPointer(varCache, context); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); - checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha, + checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, dBetaData, eps, meanCacheData, varCacheData)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index cec3b5297..0d9ae18e7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -16,21 +16,28 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; +import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; +import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; /** @@ -57,27 +64,53 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { @Override public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { - //2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166 - //Also no MKL-DNN implemented for backprop anyway + INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { + if(input.dataType() != DataType.FLOAT) + return null; //MKL-DNN only supports float /* - INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon}; + //TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335 + List args = new ArrayList<>(); + args.add(input); + args.add(meanCache); + args.add(varCache); + args.add(epsilon); + if(gamma != null) + args.add(gamma.reshape(gamma.length())); + if(beta != null) + args.add(beta.reshape(beta.length())); - INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape()); - INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, } - - BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder() - .applyBeta(gamma != null) - .applyGamma(gamma != null) - .axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases - .epsilon(eps) - .inputArrays(in) - .outputArrays(new INDArray[]{out}) + DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(args.toArray(new INDArray[0])) + .addIntegerArguments( + gamma == null ? 0 : 1, //Apply scale + beta == null ? 0 : 1, //Apply beta + 1) //Axis (NCHW) + .addFloatingPointArguments(eps) .build(); - Nd4j.exec(bn); - */ + INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape()); + INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape()); + + op.setOutputArgument(0, epsAtInput); + op.setOutputArgument(1, dLdm); + op.setOutputArgument(2, dLdv); + if(dGammaView != null) { + //Both are always null/not null simultaneously + op.setOutputArgument(3, dGammaView.reshape(dGammaView.length())); + op.setOutputArgument(4, dBetaView.reshape(dBetaView.length())); + } + + + Nd4j.exec(op); + + Gradient g = new DefaultGradient(); + g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView); + g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView); + + return new Pair<>(g, epsAtInput); + */ return null; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java new file mode 100644 index 000000000..ed9a2ef15 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java @@ -0,0 +1,168 @@ +package org.deeplearning4j.nn.layers.mkldnn; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn; +import org.deeplearning4j.nn.layers.recurrent.LSTMHelper; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.BooleanIndexing; +import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class MKLDNNLSTMHelper implements LSTMHelper { + @Override + public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) { + //TODO check other activation functions for MKLDNN + return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled(); + } + + @Override + public Pair backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, + INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, + int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, + String recurrentWeightKey, String biasWeightKey, Map gradientViews, + INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) { + //Not yet implemented/supported + return null; + } + + @Override + public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, + INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training, + INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards, + String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) { + + /* + DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward) + Inputs: + x = [bS, nIn, sL] + Wx = [nIn, 4*nOut] + Wr = [nOut, 4*nOut] + Wp = [3*nOut] Optional peephole weights + b = [4*nOut] + seqLen = [bS] + initialOut = [bs, nOut] + initialCell = [bs, nOut] + + Outputs: + out = [bS, nOut, sL] + outLast = [bs, nOut] + cellLast = [bs,nOut] + + Gates order: input, forget, input modulation, output + + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + */ + + INDArray b1d = biases.reshape(biases.length()); + INDArray seqLen = null; + if(maskArray != null){ + seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen]) + } + + List args = new ArrayList<>(); + args.add(input); + args.add(inputWeights); + args.add(recurrentWeights); + if(hasPeepholeConnections){ + throw new IllegalStateException("Not yet implemented"); + } + args.add(b1d); + if(seqLen != null) + args.add(seqLen); + if(prevOutputActivations != null) + args.add(prevOutputActivations); + if(prevMemCellState != null) + args.add(prevMemCellState); + + IActivation a = ((LSTM)conf.getLayer()).getActivationFn(); + + DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer") + .addInputs(args.toArray(new INDArray[0])) + .addBooleanArguments( + true, //hasBiases + seqLen != null, //hasSeqLen + prevOutputActivations != null, //hasInitH + prevMemCellState != null, //hasInitC + hasPeepholeConnections, //hasPh + true, //retFullSeq + true, //retLastH + true //retLastC + ) + .addIntegerArguments( + 2, //data format: 2 = [bS, nIn, sL] + 0, //direction: 0 = forward + activationToArg(gateActivationFn), //Gate activation + activationToArg(a), //Cell state activation + activationToArg(a) //Output activation (same as cell in DL4J) + ) + .build(); + + List outShapes = op.calculateOutputShape(); + + for(LongShapeDescriptor lsd : outShapes){ + INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder()); + op.addOutputArgument(arr); + } + + FwdPassReturn f = new FwdPassReturn(); + f.fwdPassOutput = op.getOutputArgument(0); + f.lastAct = op.getOutputArgument(1); + f.lastMemCell = op.getOutputArgument(2); + + return f; + } + + @Override + public Map helperMemoryUse() { + return Collections.emptyMap(); + } + + private int activationToArg(IActivation a){ + //0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + if(a instanceof ActivationTanH) + return 0; + if(a instanceof ActivationReLU) + return 1; + if(a instanceof ActivationSigmoid) + return 2; + if(a instanceof ActivationIdentity) + return 3; + if(a instanceof ActivationLReLU) + return 4; + if(a instanceof ActivationThresholdedReLU) + return 5; + if(a instanceof ActivationHardSigmoid) + return 7; + if(a instanceof ActivationELU) + return 8; + if(a instanceof ActivationSoftSign) + return 9; + if(a instanceof ActivationSoftPlus) + return 10; + throw new IllegalStateException("Unknown or not supported activation function: " + a); + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index 7f5013b1c..8c8f329ea 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer ret = null; try { - ret = helper.backpropGradient(in, eps, shape, gamma, dGammaView, dBetaView, + ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView, layerConf.getEps(), workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } @@ -451,7 +453,7 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); + Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, + INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index 3191d1dda..fe482ad62 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -144,7 +144,7 @@ public class LocalResponseNormalization } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } @@ -211,7 +211,7 @@ public class LocalResponseNormalization } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ - if(t.getMessage().contains("Failed to allocate")){ + if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){ //This is a memory exception - don't fallback to built-in implementation throw t; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java index c3e77bb99..692713f6e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTM.java @@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.LayerHelper; +import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper; +import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper; import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; @@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer - * The activation will most likely be freed later, use detach() if you need to save it.
+ * The activation will most likely be freed later, use dup() if you need to save it.
*
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}
* It is guaranteed to be called for variables from requiredVariables().
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 284c3e6ac..c79677c1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -29,6 +29,7 @@ import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.FitConfig; @@ -1642,7 +1643,13 @@ public class SameDiff extends SDBaseOps { Set requiredVars = new HashSet<>(); for (Listener l : activeListeners) { - requiredVars.addAll(l.requiredVariables(this).trainingVariables()); + ListenerVariables lv = l.requiredVariables(this); + if(lv != null) { + Set s = lv.trainingVariables(); + if(s != null) { + requiredVars.addAll(s); + } + } } List listenersWitHistory = new ArrayList<>(listeners); @@ -1661,6 +1668,10 @@ public class SameDiff extends SDBaseOps { TrainingSession ts = new TrainingSession(gradInstance); gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it + for(Listener l : activeListeners){ + l.operationStart(gradInstance, Operation.TRAINING); + } + Set paramsToTrain = new LinkedHashSet<>(); for(Variable v : variables.values()){ if(v.getVariable().getVariableType() == VariableType.VARIABLE){ @@ -1844,9 +1855,12 @@ public class SameDiff extends SDBaseOps { */ private void validateListenerActivations(List listeners, Operation op) { for (Listener l : listeners) { - for (String s : l.requiredVariables(this).requiredVariables(op)) { - if (!variables.containsKey(s)) { - Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); + ListenerVariables lv = l.requiredVariables(this); + if(lv != null) { + for (String s : lv.requiredVariables(op)) { + if (!variables.containsKey(s)) { + Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); + } } } } @@ -2151,31 +2165,20 @@ public class SameDiff extends SDBaseOps { if (hasListeners) { for (Listener l : activeListeners) { - requiredVars.addAll(l.requiredVariables(this).evaluationVariables()); + ListenerVariables v = l.requiredVariables(this); + if(v != null) { + requiredVars.addAll(v.evaluationVariables()); + } } } String[] requiredVarsArr = requiredVars.toArray(new String[0]); while (iterator.hasNext()) { - long dataStart = hasListeners ? System.currentTimeMillis() : 0; MultiDataSet ds = iterator.next(); - long dataEnd = hasListeners ? System.currentTimeMillis() : 0; Map placeholderMap = toPlaceholderMap(ds); - Map m; - Map outs = null; - if (hasListeners) { - - for (Listener l : activeListeners) { - l.iterationStart(this, at, ds, (dataEnd - dataStart)); - } - - m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); - } else { - m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); - } - + Map m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); for (Map.Entry> e : variableEvals.entrySet()) { INDArray prediction = m.get(e.getKey()); @@ -2188,15 +2191,6 @@ public class SameDiff extends SDBaseOps { } } - if (hasListeners) { - for (Listener l : activeListeners) { - Map outVars = Maps.newHashMap( - Maps.filterKeys(outs, - Predicates.in(l.requiredVariables(this).evaluationVariables()))); - l.iterationDone(this, at, ds, null); - } - } - at.setIteration(at.iteration() + 1); } @@ -2518,7 +2512,7 @@ public class SameDiff extends SDBaseOps { * Special case of {@link #batchOutput()}. */ public Map output(Map placeholders, @NonNull List outputs) { - return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec(); + return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output(); } /** @@ -2529,7 +2523,7 @@ public class SameDiff extends SDBaseOps { * Special case of {@link #batchOutput()}. */ public Map output(Map placeholders, String... outputs) { - return batchOutput().output(outputs).inputs(placeholders).exec(); + return batchOutput().output(outputs).inputs(placeholders).output(); } @@ -2542,31 +2536,36 @@ public class SameDiff extends SDBaseOps { * @param listeners Additional listeners to use during this operation. * @param outputs The variables to output and return. */ - public Map output(Map placeholders, @NonNull List listeners, String... outputs) { - return batchOutputHelper(placeholders, listeners, outputs); + public Map output(Map placeholders, List listeners, String... outputs) { + return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs); } - protected Map batchOutputHelper(Map placeholders, @NonNull List listeners, String... outputs) { + protected Map batchOutputHelper(Map placeholders, List listeners, Operation operation, String... outputs) { List activeListeners = new ArrayList<>(); + if(operation == null) + operation = Operation.INFERENCE; + for (Listener l : this.listeners) - if (l.isActive(Operation.INFERENCE)) + if (l.isActive(operation)) activeListeners.add(l); - for (Listener l : listeners) - if (l.isActive(Operation.INFERENCE)) - activeListeners.add(l); - - for (Listener l : activeListeners) { - l.operationStart(this, Operation.INFERENCE); + if(listeners != null) { + for (Listener l : listeners) + if (l.isActive(operation)) + activeListeners.add(l); } - validateListenerActivations(activeListeners, Operation.INFERENCE); + for (Listener l : activeListeners) { + l.operationStart(this, operation); + } - Map ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.emptyList(), activeListeners, outputs); + validateListenerActivations(activeListeners, operation); + + Map ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs); for (Listener l : activeListeners) { - l.operationEnd(this, Operation.INFERENCE); + l.operationEnd(this, operation); } return ret; } @@ -3992,7 +3991,6 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.put(function, sub); } - } /** @@ -4012,32 +4010,64 @@ public class SameDiff extends SDBaseOps { */ public Map calculateGradients(Map placeholderVals, @NonNull Collection variables) { Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified"); + OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables); + return oag.getGradients(); + } + + /** + * Calculate the activations and the gradients for the specified variables, in one execution call. + * This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but + * is more efficient than calling both separately. + * + * @param placeholderVals Placeholders. May be null + * @param outputVars Names of the variables that you want the activations/outputs for. May be null + * @param gradientVars Names of the variables that you want the gradient arrays for. May be null + * @return Activations and gradients, keyed by variable name + */ + public OutAndGrad calculateGradientsAndOutputs(Map placeholderVals, Collection outputVars, Collection gradientVars){ + Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()), + "No variables were specified for either output or gradients"); if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); } - List gradVarNames = new ArrayList<>(variables.size()); - for (String s : variables) { - Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s); - SDVariable v = getVariable(s).getGradient(); - if (v != null) { - //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable - gradVarNames.add(v.name()); + List varNames = new ArrayList<>(); + if(outputVars != null){ + varNames.addAll(outputVars); + } + if(gradientVars != null) { + for (String s : gradientVars) { + Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s); + SDVariable v = getVariable(s).getGradient(); + if (v != null) { + //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable + varNames.add(v.name()); + } } } //Key is gradient variable name - Map grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames); + SameDiff gradFn = getFunction(GRAD_FN_KEY); + gradFn.setListeners(listeners); + Map grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0])); - Map out = new HashMap<>(); - for (String s : variables) { - if (getVariable(s).getGradient() != null) { - String gradVar = getVariable(s).getGradient().name(); - out.put(s, grads.get(gradVar)); + Map outOutputs = outputVars == null ? null : new HashMap(); + Map outGrads = gradientVars == null ? null : new HashMap(); + if(outputVars != null){ + for(String s : outputVars){ + outOutputs.put(s, grads.get(s)); + } + } + if(gradientVars != null) { + for (String s : gradientVars) { + if (getVariable(s).getGradient() != null) { + String gradVar = getVariable(s).getGradient().name(); + outGrads.put(s, grads.get(gradVar)); + } } } - return out; + return new OutAndGrad(outOutputs, outGrads); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java new file mode 100644 index 000000000..d0bc4b8b6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/api/OutAndGrad.java @@ -0,0 +1,19 @@ +package org.nd4j.autodiff.samediff.api; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Map; + +/** + * A simple object holding two maps - one of output arrays, another of gradient arrays + */ +@AllArgsConstructor +@Data +public class OutAndGrad { + + private final Map outputs; + private final Map gradients; + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index fd89b4653..55165b530 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -797,7 +797,7 @@ public class InferenceSession extends AbstractSession { } else if (v.getVariableType() == VariableType.VARIABLE) { args[i] = v.getArr(); } else if (v.isPlaceHolder()) { - Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s); + Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", s); args[i] = placeholderValues.get(s); } else { VarId vid = lookup(s, opInputs, allIterInputs, true); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java index 598eac4f6..5fc4491c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java @@ -20,6 +20,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -125,7 +126,9 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void transformLabel(INDArray label) { - //No op + Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use" + + " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", label); + transform(label); } @Override @@ -161,7 +164,9 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void revertLabels(INDArray labels) { - //No op + Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use" + + " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", labels); + revertFeatures(labels); } @Override @@ -171,9 +176,7 @@ public class ImagePreProcessingScaler implements DataNormalization { @Override public void fitLabel(boolean fitLabels) { - if (fitLabels) { - log.warn("Labels fitting not currently supported for ImagePreProcessingScaler. Labels will not be modified"); - } + //No-op } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 1a6565ce4..25960a8a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -5831,12 +5831,6 @@ public class Nd4j { } } case UBYTE: - UInt8Buffer b = new UInt8Buffer(ArrayUtil.prod(shapeOf)); - val sb = bb.order(_order).asReadOnlyBuffer(); - for (int e = 0; e < prod; e++) - b.put(e, sb.get(e)); - - return Nd4j.create(b, shapeOf); case BFLOAT16: case UINT16: INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 5783909d8..8fe744b38 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1000,8 +1000,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { val dataType = op.resultType(); - val ret = Nd4j.createUninitialized(dataType, retShape); - op.setZ(ret); + if( op.z() == null ){ + val ret = Nd4j.createUninitialized(dataType, retShape); + op.setZ(ret); + } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){ + throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) + + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape())); + } val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType()); Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index b67d82110..69c69388b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -39,6 +39,7 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.OpValidationSuite; +import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; @@ -3426,4 +3427,30 @@ public class SameDiffTests extends BaseNd4jTest { INDArray a1 = rand1.eval(); assertEquals(a0, a1); } + + + @Test + public void testCalculateGradientsAndOutputs(){ + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable z = in.mmul(w).add("z", b); + SDVariable softmax = sd.nn.softmax("softmax", z); + + Map ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4)); + List outputs = Arrays.asList("in", "z", "softmax"); + List grads = Arrays.asList("in", "w", "z"); + + OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, outputs, grads); + Map outs = oag.getOutputs(); + Map g = oag.getGradients(); + + + Map outExp = sd.output(ph, outputs); + Map gExp = sd.calculateGradients(ph, grads); + + assertEquals(outExp, outs); + assertEquals(gExp, g); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index 8f1ab016f..e1123f5d8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -19,22 +19,28 @@ package org.nd4j.autodiff.samediff.listeners; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import java.util.Arrays; -import java.util.List; +import java.util.*; +import lombok.NonNull; import org.junit.Test; -import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.IrisDataSetIterator; +import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; @@ -49,6 +55,11 @@ public class ListenerTest extends BaseNd4jTest { super(backend); } + @Override + public char ordering() { + return 'c'; + } + @Test public void irisHistoryTest() { @@ -112,8 +123,237 @@ public class ListenerTest extends BaseNd4jTest { assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75); } - @Override - public char ordering() { - return 'c'; + @Test + public void testListenerCalls(){ + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable z = in.mmul(w).add(b); + SDVariable softmax = sd.nn.softmax("softmax", z); + SDVariable loss = sd.loss.logLoss("loss" ,label, softmax); + + TestListener tl = new TestListener(Operation.INFERENCE); + sd.setListeners(tl); + + //Check listener called during inference + Map phMap = Collections.singletonMap("in", Nd4j.rand(1, 4)); + + for( int i=1; i<=5; i++ ) { + INDArray out = sd.outputSingle(phMap, "softmax"); + + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationEndCount); + assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax + assertEquals(3*i, tl.opExecutionCount); + assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax outputs + assertEquals(0, tl.preUpdateCount); //Inference -> no updating + } + + //Check listener NOT called during inference when set to Operation.TRAINING + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + sd.outputSingle(phMap, "softmax"); + + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.emptyMap(), tl.operationStartCount); + assertEquals(Collections.emptyMap(), tl.operationEndCount); + assertEquals(0, tl.preOpExecutionCount); + assertEquals(0, tl.opExecutionCount); + assertEquals(0, tl.activationAvailableCount); + assertEquals(0, tl.preUpdateCount); + + //Check listener called during gradient calculation + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + phMap = new HashMap<>(); + phMap.put("in", Nd4j.rand( DataType.FLOAT, 1, 4)); + phMap.put("label", Nd4j.createFromArray(0f, 1f, 0f).reshape(1, 3)); + + for( int i=1; i<=3; i++ ) { + sd.calculateGradients(phMap, "in", "w", "b"); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount); + assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward + assertEquals(7*i, tl.opExecutionCount); + assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w) + assertEquals(0, tl.preUpdateCount); + } + + + //Check listener NOT called during gradient calculation - when listener is still set to INFERENCE mode + tl = new TestListener(Operation.INFERENCE); + sd.setListeners(tl); + for( int i=1; i<=3; i++ ) { + sd.calculateGradients(phMap, "in", "w", "b"); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.emptyMap(), tl.operationStartCount); + assertEquals(Collections.emptyMap(), tl.operationEndCount); + assertEquals(0, tl.preOpExecutionCount); + assertEquals(0, tl.opExecutionCount); + assertEquals(0, tl.activationAvailableCount); + assertEquals(0, tl.preUpdateCount); + } + + //Check fit: + tl = new TestListener(Operation.TRAINING); + sd.setListeners(tl); + sd.setTrainingConfig(TrainingConfig.builder() + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .updater(new Adam(1e-3)) + .build()); + + SingletonDataSetIterator dsi = new SingletonDataSetIterator(new DataSet(phMap.get("in"), phMap.get("label"))); + for( int i=1; i<=3; i++ ) { + sd.fit(dsi, 1); + assertEquals(i, tl.epochStartCount); + assertEquals(i, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(i, tl.iterationStartCount); + assertEquals(i, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount); + assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward + assertEquals(7*i, tl.opExecutionCount); + assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w) + assertEquals(2*i, tl.preUpdateCount); //w, b + } + + + //Check evaluation: + tl = new TestListener(Operation.EVALUATION); + sd.setListeners(tl); + + for( int i=1; i<=3; i++ ) { + sd.evaluate(dsi, "softmax", new Evaluation()); + assertEquals(0, tl.epochStartCount); + assertEquals(0, tl.epochEndCount); + assertEquals(0, tl.validationDoneCount); + assertEquals(0, tl.iterationStartCount); + assertEquals(0, tl.iterationDoneCount); + assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationStartCount); + assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationEndCount); + assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax + assertEquals(3*i, tl.opExecutionCount); + assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax + assertEquals(0, tl.preUpdateCount); //w, b + } + } + + private static class TestListener implements Listener { + + public TestListener(Operation operation){ + this.operation = operation; + } + + private final Operation operation; + + private int epochStartCount = 0; + private int epochEndCount = 0; + private int validationDoneCount = 0; + private int iterationStartCount = 0; + private int iterationDoneCount = 0; + private Map operationStartCount = new HashMap<>(); + private Map operationEndCount = new HashMap<>(); + private int preOpExecutionCount = 0; + private int opExecutionCount = 0; + private int activationAvailableCount = 0; + private int preUpdateCount = 0; + + + @Override + public ListenerVariables requiredVariables(SameDiff sd) { + return null; + } + + @Override + public boolean isActive(Operation operation) { + return this.operation == null || this.operation == operation; + } + + @Override + public void epochStart(SameDiff sd, At at) { + epochStartCount++; + } + + @Override + public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { + epochEndCount++; + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) { + validationDoneCount++; + return ListenerResponse.CONTINUE; + } + + @Override + public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) { + iterationStartCount++; + } + + @Override + public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) { + iterationDoneCount++; + } + + @Override + public void operationStart(SameDiff sd, Operation op) { + if(!operationStartCount.containsKey(op)) { + operationStartCount.put(op, 1); + } else { + operationStartCount.put(op, operationStartCount.get(op) + 1); + } + } + + @Override + public void operationEnd(SameDiff sd, Operation op) { + if(!operationEndCount.containsKey(op)) { + operationEndCount.put(op, 1); + } else { + operationEndCount.put(op, operationEndCount.get(op) + 1); + } + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + preOpExecutionCount++; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + opExecutionCount++; + } + + @Override + public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) { + activationAvailableCount++; + } + + @Override + public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) { + preUpdateCount++; + } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java index 1f8547a27..6dfe1e4c7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java @@ -20,11 +20,13 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.Assert.assertEquals; @@ -42,32 +44,32 @@ public class ImagePreProcessortTest extends BaseNd4jTest { @Test public void simpleImageTest() { - INDArray rChannels = Nd4j.zeros(10, 10).addi(128); - INDArray gChannels = Nd4j.zeros(10, 10).addi(64); - INDArray bChannels = Nd4j.zeros(10, 10).addi(255); - INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(3, 10, 10); + INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128); + INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64); + INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255); + INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(1, 3, 10, 10); INDArray orig = image.dup(); //System.out.println(Arrays.toString(image.shape())); - DataSet ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1)); + DataSet ds = new DataSet(image, Nd4j.ones(1, 1)); ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler(); //So this should scale to 0.5,0.25 and 1; INDArray expected = image.mul(0); - expected.slice(0, 0).addi(0.5); - expected.slice(1, 0).addi(0.25); - expected.slice(2, 0).addi(1.0); + expected.slice(0, 1).addi(0.5); + expected.slice(1, 1).addi(0.25); + expected.slice(2, 1).addi(1.0); myScaler.transform(ds); assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01); //Now giving it 16 bits instead of the default //System.out.println(Arrays.toString(image.shape())); - ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1)); + ds = new DataSet(image, Nd4j.ones(1, 1)); myScaler = new ImagePreProcessingScaler(0, 1, 16); //So this should scale to 0.5,0.25 and 1; expected = image.mul(0); - expected.slice(0, 0).addi(0.5 / 256); - expected.slice(1, 0).addi(0.25 / 256); - expected.slice(2, 0).addi(1.0 / 256); + expected.slice(0, 1).addi(0.5 / 256); + expected.slice(1, 1).addi(0.25 / 256); + expected.slice(2, 1).addi(1.0 / 256); myScaler.transform(ds); assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01); @@ -88,6 +90,16 @@ public class ImagePreProcessortTest extends BaseNd4jTest { myScaler.transform(before); myScaler.revertFeatures(before); assertEquals(orig, before); + + + //Test labels (segmentation case) + before = orig.dup(); + myScaler = new ImagePreProcessingScaler(0, 1); + myScaler.transformLabel(before); + expected = orig.div(255); + assertEquals(expected, before); + myScaler.revertLabels(before); + assertEquals(orig, before); } @Test