Various SameDiff fixes (#21)

* MKLDNN LSTM forward implementation (disabled pending #8331)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8318 add SameDiff.calculateGradientsAndOutputs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Disable mkldnn backprop for now - pending fix, issue #8335

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8337 Fix CudaExecutioner unnecessary result array allocation/replacement

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small FlatBuffers serde fix, UInt8

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8135 ImagePreProcessingScaler - add segmentation support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8319 Ensure listeners are called when they are supposed to be called

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8214 UNet (non-pretrained) last conv layer kernel size fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-02 11:25:53 +11:00 committed by GitHub
parent b816845797
commit d82877b18b
20 changed files with 827 additions and 119 deletions

View File

@ -0,0 +1,107 @@
package org.deeplearning4j;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import static junit.framework.TestCase.*;
public class TestBatchNormBp {
@Test
public void test(){
Nd4j.getRandom().setSeed(12345);
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
double e = 1e-5;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, eps, gamma, beta)
.addIntegerArguments(
1, //Apply scale
1, //Apply beta
1) //Axis (NCHW)
.addFloatingPointArguments(e)
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
.build();
Nd4j.exec(op);
System.out.println(dLdIn);
}
@Test
public void compareImpls() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
h.preOutput(in, true, new int[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new int[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
}
}

View File

@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import static junit.framework.TestCase.*;
import static org.junit.Assume.assumeTrue;
public class ValidateMKLDNN extends BaseDL4JTest {
@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
.padding(0, 0)
.nOut(3)
.build())
.layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build())
.layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
.kernelSize(kernel)
.stride(stride)
@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest {
}
}
}
@Test
public void compareBatchNormBackward() throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
INDArray var = in.var(0, 2, 3).reshape(1,3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
double e = 1e-3;
INDArray dLdIn = in.ulike();
INDArray dLdm = mean.ulike();
INDArray dLdv = var.ulike();
INDArray dLdg = gamma.ulike();
INDArray dLdb = beta.ulike();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.inferenceWorkspaceMode(WorkspaceMode.NONE)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.list()
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
assertNotNull(bn.getHelper());
System.out.println(bn.getHelper());
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
Field f = bn.getClass().getDeclaredField("helper");
f.setAccessible(true);
f.set(bn, null);
assertNull(bn.getHelper());
net.output(in, true);
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
INDArray dldin_dl4j = p.getSecond();
INDArray dldin_helper = pcudnn.getSecond();
assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5));
}
}

View File

@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
}
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
this.eps = eps;
val miniBatch = (int) input.size(0);
@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
Pointer varCacheData = allocator.getPointer(varCache, context);
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
dBetaData, eps, meanCacheData, varCacheData));

View File

@ -16,21 +16,28 @@
package org.deeplearning4j.nn.layers.mkldnn;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
@ -57,27 +64,53 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
//2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166
//Also no MKL-DNN implemented for backprop anyway
INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
if(input.dataType() != DataType.FLOAT)
return null; //MKL-DNN only supports float
/*
INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon};
//TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335
List<INDArray> args = new ArrayList<>();
args.add(input);
args.add(meanCache);
args.add(varCache);
args.add(epsilon);
if(gamma != null)
args.add(gamma.reshape(gamma.length()));
if(beta != null)
args.add(beta.reshape(beta.length()));
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape());
INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, }
BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder()
.applyBeta(gamma != null)
.applyGamma(gamma != null)
.axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases
.epsilon(eps)
.inputArrays(in)
.outputArrays(new INDArray[]{out})
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(args.toArray(new INDArray[0]))
.addIntegerArguments(
gamma == null ? 0 : 1, //Apply scale
beta == null ? 0 : 1, //Apply beta
1) //Axis (NCHW)
.addFloatingPointArguments(eps)
.build();
Nd4j.exec(bn);
*/
INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
op.setOutputArgument(0, epsAtInput);
op.setOutputArgument(1, dLdm);
op.setOutputArgument(2, dLdv);
if(dGammaView != null) {
//Both are always null/not null simultaneously
op.setOutputArgument(3, dGammaView.reshape(dGammaView.length()));
op.setOutputArgument(4, dBetaView.reshape(dBetaView.length()));
}
Nd4j.exec(op);
Gradient g = new DefaultGradient();
g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
return new Pair<>(g, epsAtInput);
*/
return null;
}

View File

@ -0,0 +1,168 @@
package org.deeplearning4j.nn.layers.mkldnn;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class MKLDNNLSTMHelper implements LSTMHelper {
@Override
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
//TODO check other activation functions for MKLDNN
return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled();
}
@Override
public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT,
int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey,
String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews,
INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
//Not yet implemented/supported
return null;
}
@Override
public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training,
INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards,
String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
/*
DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward)
Inputs:
x = [bS, nIn, sL]
Wx = [nIn, 4*nOut]
Wr = [nOut, 4*nOut]
Wp = [3*nOut] Optional peephole weights
b = [4*nOut]
seqLen = [bS]
initialOut = [bs, nOut]
initialCell = [bs, nOut]
Outputs:
out = [bS, nOut, sL]
outLast = [bs, nOut]
cellLast = [bs,nOut]
Gates order: input, forget, input modulation, output
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
*/
INDArray b1d = biases.reshape(biases.length());
INDArray seqLen = null;
if(maskArray != null){
seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen])
}
List<INDArray> args = new ArrayList<>();
args.add(input);
args.add(inputWeights);
args.add(recurrentWeights);
if(hasPeepholeConnections){
throw new IllegalStateException("Not yet implemented");
}
args.add(b1d);
if(seqLen != null)
args.add(seqLen);
if(prevOutputActivations != null)
args.add(prevOutputActivations);
if(prevMemCellState != null)
args.add(prevMemCellState);
IActivation a = ((LSTM)conf.getLayer()).getActivationFn();
DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
.addInputs(args.toArray(new INDArray[0]))
.addBooleanArguments(
true, //hasBiases
seqLen != null, //hasSeqLen
prevOutputActivations != null, //hasInitH
prevMemCellState != null, //hasInitC
hasPeepholeConnections, //hasPh
true, //retFullSeq
true, //retLastH
true //retLastC
)
.addIntegerArguments(
2, //data format: 2 = [bS, nIn, sL]
0, //direction: 0 = forward
activationToArg(gateActivationFn), //Gate activation
activationToArg(a), //Cell state activation
activationToArg(a) //Output activation (same as cell in DL4J)
)
.build();
List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
for(LongShapeDescriptor lsd : outShapes){
INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
op.addOutputArgument(arr);
}
FwdPassReturn f = new FwdPassReturn();
f.fwdPassOutput = op.getOutputArgument(0);
f.lastAct = op.getOutputArgument(1);
f.lastMemCell = op.getOutputArgument(2);
return f;
}
@Override
public Map<String, Long> helperMemoryUse() {
return Collections.emptyMap();
}
private int activationToArg(IActivation a){
//0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
if(a instanceof ActivationTanH)
return 0;
if(a instanceof ActivationReLU)
return 1;
if(a instanceof ActivationSigmoid)
return 2;
if(a instanceof ActivationIdentity)
return 3;
if(a instanceof ActivationLReLU)
return 4;
if(a instanceof ActivationThresholdedReLU)
return 5;
if(a instanceof ActivationHardSigmoid)
return 7;
if(a instanceof ActivationELU)
return 8;
if(a instanceof ActivationSoftSign)
return 9;
if(a instanceof ActivationSoftPlus)
return 10;
throw new IllegalStateException("Unknown or not supported activation function: " + a);
}
}

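For reference, a minimal standalone sketch of invoking the lstmLayer op the same way the helper above does, following the shapes and argument order documented in its comment block (x = [bS, nIn, sL], Wx = [nIn, 4*nOut], Wr = [nOut, 4*nOut], b = [4*nOut]). The concrete sizes below are illustrative only, and the fragment assumes the same imports as MKLDNNLSTMHelper.

int bS = 2, nIn = 3, sL = 5, nOut = 4;                    //Illustrative sizes only
INDArray x  = Nd4j.rand(DataType.FLOAT, bS, nIn, sL);     //Input sequence: [bS, nIn, sL]
INDArray Wx = Nd4j.rand(DataType.FLOAT, nIn, 4 * nOut);   //Input weights: [nIn, 4*nOut]
INDArray Wr = Nd4j.rand(DataType.FLOAT, nOut, 4 * nOut);  //Recurrent weights: [nOut, 4*nOut]
INDArray b  = Nd4j.rand(DataType.FLOAT, 4 * nOut);        //Biases, 1d: [4*nOut]

DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
        .addInputs(x, Wx, Wr, b)
        //hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC
        .addBooleanArguments(true, false, false, false, false, true, true, true)
        //dataFormat 2 = [bS, nIn, sL], direction 0 = forward, gate act 2 = sigmoid, cell/output act 0 = tanh
        .addIntegerArguments(2, 0, 2, 0, 0)
        .build();

//Allocate outputs (full sequence, last output, last cell state) from the calculated shapes, then execute
for (LongShapeDescriptor lsd : op.calculateOutputShape()) {
    op.addOutputArgument(Nd4j.createUninitialized(lsd.dataType(), lsd.getShape(), lsd.getOrder()));
}
Nd4j.exec(op);
INDArray fullTimeSeries = op.getOutputArgument(0);        //[bS, nOut, sL]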
View File

@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config
INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
INDArray gamma = null;
INDArray beta = null;
INDArray dGammaView;
INDArray dBetaView;
INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
@ -129,6 +130,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
} else {
gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
beta = getParam(BatchNormalizationParamInitializer.BETA);
dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
}
@ -154,12 +156,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
Pair<Gradient,INDArray> ret = null;
try {
ret = helper.backpropGradient(in, eps, shape, gamma, dGammaView, dBetaView,
ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView,
layerConf.getEps(), workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw t;
}
@ -451,7 +453,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t) {
if(t.getMessage().contains("Failed to allocate")){
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw t;
}

View File

@ -31,8 +31,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
public interface BatchNormalizationHelper extends LayerHelper {
boolean checkSupported(double eps, boolean fixedGammaBeta);
Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);

View File

@ -144,7 +144,7 @@ public class LocalResponseNormalization
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw t;
}
@ -211,7 +211,7 @@ public class LocalResponseNormalization
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw t;
}

View File

@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
}
}
}
/*
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
else if ("CPU".equalsIgnoreCase(backend) && BaseMKLDNNHelper.mklDnnEnabled()){
helper = new MKLDNNLSTMHelper();
log.debug("MKLDNNLSTMHelper successfully initialized");
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
helper = null;
}
}
*/
}
@Override

View File

@ -215,7 +215,7 @@ public class UNet extends ZooModel {
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv9-2")
.addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1)
.addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(1)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.IDENTITY).build(), "conv9-3")
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)

View File

@ -122,7 +122,7 @@ public interface Listener {
/**
* Called when any activation becomes available.
* <p>
* The activation will most likely be freed later, use detach() if you need to save it.<br>
* The activation will most likely be freed later, use dup() if you need to save it.<br>
* <br>
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
* It is guaranteed to be called for variables from requiredVariables().<br>

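As a usage sketch of the contract described in that javadoc: copy an activation with dup() before keeping it, since the original array may be freed or reused once execution moves on. The listener name is illustrative, and the sketch assumes the BaseListener adapter class from the same listeners package.

public class ActivationCapturingListener extends BaseListener {
    private final Map<String, INDArray> saved = new HashMap<>();

    @Override
    public boolean isActive(Operation operation) {
        return true;    //Listen during all operations
    }

    @Override
    public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                    String varName, INDArray activation) {
        //dup() the activation - keeping a reference to the original array is not safe
        saved.put(varName, activation.dup());
    }
}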
View File

@ -29,6 +29,7 @@ import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
@ -1642,7 +1643,13 @@ public class SameDiff extends SDBaseOps {
Set<String> requiredVars = new HashSet<>();
for (Listener l : activeListeners) {
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
ListenerVariables lv = l.requiredVariables(this);
if(lv != null) {
Set<String> s = lv.trainingVariables();
if(s != null) {
requiredVars.addAll(s);
}
}
}
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
@ -1661,6 +1668,10 @@ public class SameDiff extends SDBaseOps {
TrainingSession ts = new TrainingSession(gradInstance);
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
for(Listener l : activeListeners){
l.operationStart(gradInstance, Operation.TRAINING);
}
Set<String> paramsToTrain = new LinkedHashSet<>();
for(Variable v : variables.values()){
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
@ -1844,9 +1855,12 @@ public class SameDiff extends SDBaseOps {
*/
private void validateListenerActivations(List<Listener> listeners, Operation op) {
for (Listener l : listeners) {
for (String s : l.requiredVariables(this).requiredVariables(op)) {
if (!variables.containsKey(s)) {
Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
ListenerVariables lv = l.requiredVariables(this);
if(lv != null) {
for (String s : lv.requiredVariables(op)) {
if (!variables.containsKey(s)) {
Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
}
}
}
}
@ -2151,31 +2165,20 @@ public class SameDiff extends SDBaseOps {
if (hasListeners) {
for (Listener l : activeListeners) {
requiredVars.addAll(l.requiredVariables(this).evaluationVariables());
ListenerVariables v = l.requiredVariables(this);
if(v != null) {
requiredVars.addAll(v.evaluationVariables());
}
}
}
String[] requiredVarsArr = requiredVars.toArray(new String[0]);
while (iterator.hasNext()) {
long dataStart = hasListeners ? System.currentTimeMillis() : 0;
MultiDataSet ds = iterator.next();
long dataEnd = hasListeners ? System.currentTimeMillis() : 0;
Map<String, INDArray> placeholderMap = toPlaceholderMap(ds);
Map<String, INDArray> m;
Map<String, INDArray> outs = null;
if (hasListeners) {
for (Listener l : activeListeners) {
l.iterationStart(this, at, ds, (dataEnd - dataStart));
}
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
} else {
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
}
Map<String, INDArray> m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
INDArray prediction = m.get(e.getKey());
@ -2188,15 +2191,6 @@ public class SameDiff extends SDBaseOps {
}
}
if (hasListeners) {
for (Listener l : activeListeners) {
Map<String, INDArray> outVars = Maps.newHashMap(
Maps.filterKeys(outs,
Predicates.in(l.requiredVariables(this).evaluationVariables())));
l.iterationDone(this, at, ds, null);
}
}
at.setIteration(at.iteration() + 1);
}
@ -2518,7 +2512,7 @@ public class SameDiff extends SDBaseOps {
* Special case of {@link #batchOutput()}.
*/
public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec();
return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output();
}
/**
@ -2529,7 +2523,7 @@ public class SameDiff extends SDBaseOps {
* Special case of {@link #batchOutput()}.
*/
public Map<String, INDArray> output(Map<String, INDArray> placeholders, String... outputs) {
return batchOutput().output(outputs).inputs(placeholders).exec();
return batchOutput().output(outputs).inputs(placeholders).output();
}
@ -2542,31 +2536,36 @@ public class SameDiff extends SDBaseOps {
* @param listeners Additional listeners to use during this operation.
* @param outputs The variables to output and return.
*/
public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
return batchOutputHelper(placeholders, listeners, outputs);
public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<Listener> listeners, String... outputs) {
return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs);
}
protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs) {
List<Listener> activeListeners = new ArrayList<>();
if(operation == null)
operation = Operation.INFERENCE;
for (Listener l : this.listeners)
if (l.isActive(Operation.INFERENCE))
if (l.isActive(operation))
activeListeners.add(l);
for (Listener l : listeners)
if (l.isActive(Operation.INFERENCE))
activeListeners.add(l);
for (Listener l : activeListeners) {
l.operationStart(this, Operation.INFERENCE);
if(listeners != null) {
for (Listener l : listeners)
if (l.isActive(operation))
activeListeners.add(l);
}
validateListenerActivations(activeListeners, Operation.INFERENCE);
for (Listener l : activeListeners) {
l.operationStart(this, operation);
}
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.<String>emptyList(), activeListeners, outputs);
validateListenerActivations(activeListeners, operation);
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);
for (Listener l : activeListeners) {
l.operationEnd(this, Operation.INFERENCE);
l.operationEnd(this, operation);
}
return ret;
}
@ -3992,7 +3991,6 @@ public class SameDiff extends SDBaseOps {
sameDiffFunctionInstances.put(function, sub);
}
}
/**
@ -4012,32 +4010,64 @@ public class SameDiff extends SDBaseOps {
*/
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables);
return oag.getGradients();
}
/**
* Calculate the activations and the gradients for the specified variables, in one execution call.
* This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but
* is more efficient than calling both separately.
*
* @param placeholderVals Placeholders. May be null
* @param outputVars Names of the variables that you want the activations/outputs for. May be null
* @param gradientVars Names of the variables that you want the gradient arrays for. May be null
* @return Activations and gradients, keyed by variable name
*/
public OutAndGrad calculateGradientsAndOutputs(Map<String,INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars){
Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()),
"No variables were specified for either output or gradients");
if (getFunction(GRAD_FN_KEY) == null) {
createGradFunction();
}
List<String> gradVarNames = new ArrayList<>(variables.size());
for (String s : variables) {
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
SDVariable v = getVariable(s).getGradient();
if (v != null) {
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
gradVarNames.add(v.name());
List<String> varNames = new ArrayList<>();
if(outputVars != null){
varNames.addAll(outputVars);
}
if(gradientVars != null) {
for (String s : gradientVars) {
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
SDVariable v = getVariable(s).getGradient();
if (v != null) {
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
varNames.add(v.name());
}
}
}
//Key is gradient variable name
Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);
SameDiff gradFn = getFunction(GRAD_FN_KEY);
gradFn.setListeners(listeners);
Map<String, INDArray> grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0]));
Map<String, INDArray> out = new HashMap<>();
for (String s : variables) {
if (getVariable(s).getGradient() != null) {
String gradVar = getVariable(s).getGradient().name();
out.put(s, grads.get(gradVar));
Map<String, INDArray> outOutputs = outputVars == null ? null : new HashMap<String,INDArray>();
Map<String, INDArray> outGrads = gradientVars == null ? null : new HashMap<String,INDArray>();
if(outputVars != null){
for(String s : outputVars){
outOutputs.put(s, grads.get(s));
}
}
if(gradientVars != null) {
for (String s : gradientVars) {
if (getVariable(s).getGradient() != null) {
String gradVar = getVariable(s).getGradient().name();
outGrads.put(s, grads.get(gradVar));
}
}
}
return out;
return new OutAndGrad(outOutputs, outGrads);
}
/**

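Condensed from the new SameDiffTests case further below, a sketch of the new calculateGradientsAndOutputs call: a single execution pass returns both the requested activations and the requested gradients in an OutAndGrad holder.

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
SDVariable z = in.mmul(w).add("z", b);
SDVariable softmax = sd.nn.softmax("softmax", z);

Map<String, INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));

OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, Arrays.asList("z", "softmax"), Arrays.asList("in", "w", "z"));
Map<String, INDArray> outputs = oag.getOutputs();       //Keyed by output variable name
Map<String, INDArray> gradients = oag.getGradients();   //Keyed by the variable whose gradient was requested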
View File

@ -0,0 +1,19 @@
package org.nd4j.autodiff.samediff.api;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;
/**
* A simple object holding two maps - one of output arrays, another of gradient arrays
*/
@AllArgsConstructor
@Data
public class OutAndGrad {
private final Map<String, INDArray> outputs;
private final Map<String, INDArray> gradients;
}

View File

@ -797,7 +797,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
} else if (v.getVariableType() == VariableType.VARIABLE) {
args[i] = v.getArr();
} else if (v.isPlaceHolder()) {
Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s);
Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", s);
args[i] = placeholderValues.get(s);
} else {
VarId vid = lookup(s, opInputs, allIterInputs, true);

View File

@ -20,6 +20,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -125,7 +126,9 @@ public class ImagePreProcessingScaler implements DataNormalization {
@Override
public void transformLabel(INDArray label) {
//No op
Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use" +
" cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", label);
transform(label);
}
@Override
@ -161,7 +164,9 @@ public class ImagePreProcessingScaler implements DataNormalization {
@Override
public void revertLabels(INDArray labels) {
//No op
Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use" +
" cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", labels);
revertFeatures(labels);
}
@Override
@ -171,9 +176,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
@Override
public void fitLabel(boolean fitLabels) {
if (fitLabels) {
log.warn("Labels fitting not currently supported for ImagePreProcessingScaler. Labels will not be modified");
}
//No-op
}
@Override

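A short sketch of the new segmentation-label behaviour (mirroring the test added in ImagePreProcessortTest below; the mask shape is illustrative): labels must be rank 4 image-shaped arrays, and are scaled and reverted exactly like features.

INDArray labels = Nd4j.rand(DataType.FLOAT, 1, 1, 10, 10).muli(255);   //Rank 4 segmentation mask, values in 0-255

ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
scaler.transformLabel(labels);   //Now scaled to [0, 1]
scaler.revertLabels(labels);     //Back to the original 0-255 range
//A rank-2 label array (e.g. one-hot classification labels) would fail the rank-4 precondition above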
View File

@ -5831,12 +5831,6 @@ public class Nd4j {
}
}
case UBYTE:
UInt8Buffer b = new UInt8Buffer(ArrayUtil.prod(shapeOf));
val sb = bb.order(_order).asReadOnlyBuffer();
for (int e = 0; e < prod; e++)
b.put(e, sb.get(e));
return Nd4j.create(b, shapeOf);
case BFLOAT16:
case UINT16:
INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf);

View File

@ -1000,8 +1000,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val dataType = op.resultType();
val ret = Nd4j.createUninitialized(dataType, retShape);
op.setZ(ret);
if( op.z() == null ){
val ret = Nd4j.createUninitialized(dataType, retShape);
op.setZ(ret);
} else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){
throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape)
+ " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape()));
}
val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;

View File

@ -39,6 +39,7 @@ import org.junit.Ignore;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
@ -3426,4 +3427,30 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray a1 = rand1.eval();
assertEquals(a0, a1);
}
@Test
public void testCalculateGradientsAndOutputs(){
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
SDVariable z = in.mmul(w).add("z", b);
SDVariable softmax = sd.nn.softmax("softmax", z);
Map<String,INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
List<String> outputs = Arrays.asList("in", "z", "softmax");
List<String> grads = Arrays.asList("in", "w", "z");
OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, outputs, grads);
Map<String,INDArray> outs = oag.getOutputs();
Map<String,INDArray> g = oag.getGradients();
Map<String,INDArray> outExp = sd.output(ph, outputs);
Map<String,INDArray> gExp = sd.calculateGradients(ph, grads);
assertEquals(outExp, outs);
assertEquals(gExp, g);
}
}

View File

@ -19,22 +19,28 @@ package org.nd4j.autodiff.samediff.listeners;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.List;
import java.util.*;
import lombok.NonNull;
import org.junit.Test;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
@ -49,6 +55,11 @@ public class ListenerTest extends BaseNd4jTest {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void irisHistoryTest() {
@ -112,8 +123,237 @@ public class ListenerTest extends BaseNd4jTest {
assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
}
@Override
public char ordering() {
return 'c';
@Test
public void testListenerCalls(){
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
SDVariable z = in.mmul(w).add(b);
SDVariable softmax = sd.nn.softmax("softmax", z);
SDVariable loss = sd.loss.logLoss("loss" ,label, softmax);
TestListener tl = new TestListener(Operation.INFERENCE);
sd.setListeners(tl);
//Check listener called during inference
Map<String,INDArray> phMap = Collections.singletonMap("in", Nd4j.rand(1, 4));
for( int i=1; i<=5; i++ ) {
INDArray out = sd.outputSingle(phMap, "softmax");
assertEquals(0, tl.epochStartCount);
assertEquals(0, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(0, tl.iterationStartCount);
assertEquals(0, tl.iterationDoneCount);
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationStartCount);
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationEndCount);
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
assertEquals(3*i, tl.opExecutionCount);
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax outputs
assertEquals(0, tl.preUpdateCount); //Inference -> no updating
}
//Check listener NOT called during inference when set to Operation.TRAINING
tl = new TestListener(Operation.TRAINING);
sd.setListeners(tl);
sd.outputSingle(phMap, "softmax");
assertEquals(0, tl.epochStartCount);
assertEquals(0, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(0, tl.iterationStartCount);
assertEquals(0, tl.iterationDoneCount);
assertEquals(Collections.emptyMap(), tl.operationStartCount);
assertEquals(Collections.emptyMap(), tl.operationEndCount);
assertEquals(0, tl.preOpExecutionCount);
assertEquals(0, tl.opExecutionCount);
assertEquals(0, tl.activationAvailableCount);
assertEquals(0, tl.preUpdateCount);
//Check listener called during gradient calculation
tl = new TestListener(Operation.TRAINING);
sd.setListeners(tl);
phMap = new HashMap<>();
phMap.put("in", Nd4j.rand( DataType.FLOAT, 1, 4));
phMap.put("label", Nd4j.createFromArray(0f, 1f, 0f).reshape(1, 3));
for( int i=1; i<=3; i++ ) {
sd.calculateGradients(phMap, "in", "w", "b");
assertEquals(0, tl.epochStartCount);
assertEquals(0, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(0, tl.iterationStartCount);
assertEquals(0, tl.iterationDoneCount);
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
assertEquals(7*i, tl.opExecutionCount);
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
assertEquals(0, tl.preUpdateCount);
}
//Check listener NOT called during gradient calculation - when listener is still set to INFERENCE mode
tl = new TestListener(Operation.INFERENCE);
sd.setListeners(tl);
for( int i=1; i<=3; i++ ) {
sd.calculateGradients(phMap, "in", "w", "b");
assertEquals(0, tl.epochStartCount);
assertEquals(0, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(0, tl.iterationStartCount);
assertEquals(0, tl.iterationDoneCount);
assertEquals(Collections.emptyMap(), tl.operationStartCount);
assertEquals(Collections.emptyMap(), tl.operationEndCount);
assertEquals(0, tl.preOpExecutionCount);
assertEquals(0, tl.opExecutionCount);
assertEquals(0, tl.activationAvailableCount);
assertEquals(0, tl.preUpdateCount);
}
//Check fit:
tl = new TestListener(Operation.TRAINING);
sd.setListeners(tl);
sd.setTrainingConfig(TrainingConfig.builder()
.dataSetFeatureMapping("in")
.dataSetLabelMapping("label")
.updater(new Adam(1e-3))
.build());
SingletonDataSetIterator dsi = new SingletonDataSetIterator(new DataSet(phMap.get("in"), phMap.get("label")));
for( int i=1; i<=3; i++ ) {
sd.fit(dsi, 1);
assertEquals(i, tl.epochStartCount);
assertEquals(i, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(i, tl.iterationStartCount);
assertEquals(i, tl.iterationDoneCount);
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
assertEquals(7*i, tl.opExecutionCount);
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
assertEquals(2*i, tl.preUpdateCount); //w, b
}
//Check evaluation:
tl = new TestListener(Operation.EVALUATION);
sd.setListeners(tl);
for( int i=1; i<=3; i++ ) {
sd.evaluate(dsi, "softmax", new Evaluation());
assertEquals(0, tl.epochStartCount);
assertEquals(0, tl.epochEndCount);
assertEquals(0, tl.validationDoneCount);
assertEquals(0, tl.iterationStartCount);
assertEquals(0, tl.iterationDoneCount);
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationStartCount);
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationEndCount);
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
assertEquals(3*i, tl.opExecutionCount);
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax
assertEquals(0, tl.preUpdateCount); //w, b
}
}
private static class TestListener implements Listener {
public TestListener(Operation operation){
this.operation = operation;
}
private final Operation operation;
private int epochStartCount = 0;
private int epochEndCount = 0;
private int validationDoneCount = 0;
private int iterationStartCount = 0;
private int iterationDoneCount = 0;
private Map<Operation,Integer> operationStartCount = new HashMap<>();
private Map<Operation,Integer> operationEndCount = new HashMap<>();
private int preOpExecutionCount = 0;
private int opExecutionCount = 0;
private int activationAvailableCount = 0;
private int preUpdateCount = 0;
@Override
public ListenerVariables requiredVariables(SameDiff sd) {
return null;
}
@Override
public boolean isActive(Operation operation) {
return this.operation == null || this.operation == operation;
}
@Override
public void epochStart(SameDiff sd, At at) {
epochStartCount++;
}
@Override
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
epochEndCount++;
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
validationDoneCount++;
return ListenerResponse.CONTINUE;
}
@Override
public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) {
iterationStartCount++;
}
@Override
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
iterationDoneCount++;
}
@Override
public void operationStart(SameDiff sd, Operation op) {
if(!operationStartCount.containsKey(op)) {
operationStartCount.put(op, 1);
} else {
operationStartCount.put(op, operationStartCount.get(op) + 1);
}
}
@Override
public void operationEnd(SameDiff sd, Operation op) {
if(!operationEndCount.containsKey(op)) {
operationEndCount.put(op, 1);
} else {
operationEndCount.put(op, operationEndCount.get(op) + 1);
}
}
@Override
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
preOpExecutionCount++;
}
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
opExecutionCount++;
}
@Override
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) {
activationAvailableCount++;
}
@Override
public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
preUpdateCount++;
}
}
}

View File

@ -20,11 +20,13 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import static org.junit.Assert.assertEquals;
@ -42,32 +44,32 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
@Test
public void simpleImageTest() {
INDArray rChannels = Nd4j.zeros(10, 10).addi(128);
INDArray gChannels = Nd4j.zeros(10, 10).addi(64);
INDArray bChannels = Nd4j.zeros(10, 10).addi(255);
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(3, 10, 10);
INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128);
INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64);
INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255);
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(1, 3, 10, 10);
INDArray orig = image.dup();
//System.out.println(Arrays.toString(image.shape()));
DataSet ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
DataSet ds = new DataSet(image, Nd4j.ones(1, 1));
ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler();
//So this should scale to 0.5,0.25 and 1;
INDArray expected = image.mul(0);
expected.slice(0, 0).addi(0.5);
expected.slice(1, 0).addi(0.25);
expected.slice(2, 0).addi(1.0);
expected.slice(0, 1).addi(0.5);
expected.slice(1, 1).addi(0.25);
expected.slice(2, 1).addi(1.0);
myScaler.transform(ds);
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
//Now giving it 16 bits instead of the default
//System.out.println(Arrays.toString(image.shape()));
ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
ds = new DataSet(image, Nd4j.ones(1, 1));
myScaler = new ImagePreProcessingScaler(0, 1, 16);
//So this should scale to 0.5,0.25 and 1;
expected = image.mul(0);
expected.slice(0, 0).addi(0.5 / 256);
expected.slice(1, 0).addi(0.25 / 256);
expected.slice(2, 0).addi(1.0 / 256);
expected.slice(0, 1).addi(0.5 / 256);
expected.slice(1, 1).addi(0.25 / 256);
expected.slice(2, 1).addi(1.0 / 256);
myScaler.transform(ds);
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
@ -88,6 +90,16 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
myScaler.transform(before);
myScaler.revertFeatures(before);
assertEquals(orig, before);
//Test labels (segmentation case)
before = orig.dup();
myScaler = new ImagePreProcessingScaler(0, 1);
myScaler.transformLabel(before);
expected = orig.div(255);
assertEquals(expected, before);
myScaler.revertLabels(before);
assertEquals(orig, before);
}
@Test