Various SameDiff fixes (#21)

* MKLDNN LSTM forward implementation (disabled pending #8331) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8318 add SameDiff.calculateGradientsAndOutputs Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Disable mkldnn backprop for now - pending fix, issue #8335 Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8337 Fix CudaExecutioner unnecessary result array allocation/replacement Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small FlatBuffers serde fix, UInt8 Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8135 ImagePreProcessingScaler - add segmentation support Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8319 Ensure listeners are called when they are supposed to be called Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8214 UNet (non-pretrained) last conv layer kernel size fix Signed-off-by: AlexDBlack <blacka101@gmail.com>

parent b816845797
commit d82877b18b
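For reference, a minimal usage sketch of the new SameDiff.calculateGradientsAndOutputs API (#8318), condensed from the testCalculateGradientsAndOutputs unit test added later in this diff; the graph, placeholder values, and class name here are illustrative only:

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CalculateGradientsAndOutputsExample {
    public static void main(String[] args) {
        //Small graph: softmax(in * w + b), mirroring the new unit test in this commit
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
        SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
        SDVariable z = in.mmul(w).add("z", b);
        sd.nn.softmax("softmax", z);

        Map<String, INDArray> placeholders = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
        List<String> outputs = Arrays.asList("in", "z", "softmax");     //Variables to return activations for
        List<String> gradients = Arrays.asList("in", "w", "z");         //Variables to return gradient arrays for

        //Single execution call returning both activations and gradients (previously two separate calls)
        OutAndGrad oag = sd.calculateGradientsAndOutputs(placeholders, outputs, gradients);
        Map<String, INDArray> outs = oag.getOutputs();
        Map<String, INDArray> grads = oag.getGradients();
        System.out.println(outs.keySet() + " / " + grads.keySet());
    }
}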
@@ -0,0 +1,107 @@
package org.deeplearning4j;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.lang.reflect.Field;

import static junit.framework.TestCase.*;

public class TestBatchNormBp {

    @Test
    public void test(){
        Nd4j.getRandom().setSeed(12345);
        // INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
        INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
        INDArray mean = in.mean(0, 2, 3);   //Nd4j.rand(DataType.FLOAT, 3);
        INDArray var = in.var(0, 2, 3);     //Nd4j.rand(DataType.FLOAT, 3);
        INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
        // INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
        // INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
        INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
        INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
        double e = 1e-5;

        INDArray dLdIn = in.ulike();
        INDArray dLdm = mean.ulike();
        INDArray dLdv = var.ulike();
        INDArray dLdg = gamma.ulike();
        INDArray dLdb = beta.ulike();

        DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
                .addInputs(in, mean, var, eps, gamma, beta)
                .addIntegerArguments(
                        1,          //Apply scale
                        1,          //Apply beta
                        1)          //Axis (NCHW)
                .addFloatingPointArguments(e)
                .addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
                .build();

        Nd4j.exec(op);
        System.out.println(dLdIn);
    }

    @Test
    public void compareImpls() throws Exception {

        Nd4j.getRandom().setSeed(12345);
        INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
        INDArray mean = in.mean(0, 2, 3).reshape(1,3);
        INDArray var = in.var(0, 2, 3).reshape(1,3);
        INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
        INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
        INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
        double e = 1e-3;

        INDArray dLdIn = in.ulike();
        INDArray dLdm = mean.ulike();
        INDArray dLdv = var.ulike();
        INDArray dLdg = gamma.ulike();
        INDArray dLdb = beta.ulike();


        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .inferenceWorkspaceMode(WorkspaceMode.NONE)
                .trainingWorkspaceMode(WorkspaceMode.NONE)
                .list()
                .layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
        assertNotNull(bn.getHelper());
        Field f = bn.getClass().getDeclaredField("helper");
        f.setAccessible(true);
        f.set(bn, null);
        assertNull(bn.getHelper());


        MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);

        net.output(in, true);
        bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
        Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());

        h.preOutput(in, true, new int[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
        Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new int[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());

        INDArray dldin_dl4j = p.getSecond();

        System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
    }

}
@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
|||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
|
@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static junit.framework.TestCase.*;
|
||||
import static org.junit.Assume.assumeTrue;
|
||||
|
||||
public class ValidateMKLDNN extends BaseDL4JTest {
|
||||
|
@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
.padding(0, 0)
|
||||
.nOut(3)
|
||||
.build())
|
||||
.layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build())
|
||||
.layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
|
||||
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
|
||||
.kernelSize(kernel)
|
||||
.stride(stride)
|
||||
|
@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void compareBatchNormBackward() throws Exception {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
|
||||
INDArray var = in.var(0, 2, 3).reshape(1,3);
|
||||
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||
double e = 1e-3;
|
||||
|
||||
INDArray dLdIn = in.ulike();
|
||||
INDArray dLdm = mean.ulike();
|
||||
INDArray dLdv = var.ulike();
|
||||
INDArray dLdg = gamma.ulike();
|
||||
INDArray dLdb = beta.ulike();
|
||||
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.inferenceWorkspaceMode(WorkspaceMode.NONE)
|
||||
.trainingWorkspaceMode(WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
|
||||
.build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
|
||||
assertNotNull(bn.getHelper());
|
||||
System.out.println(bn.getHelper());
|
||||
|
||||
net.output(in, true);
|
||||
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
Field f = bn.getClass().getDeclaredField("helper");
|
||||
f.setAccessible(true);
|
||||
f.set(bn, null);
|
||||
assertNull(bn.getHelper());
|
||||
|
||||
net.output(in, true);
|
||||
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
||||
INDArray dldin_dl4j = p.getSecond();
|
||||
INDArray dldin_helper = pcudnn.getSecond();
|
||||
|
||||
assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5));
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
                                                     INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.eps = eps;
        val miniBatch = (int) input.size(0);

@@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
        Pointer varCacheData = allocator.getPointer(varCache, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
        checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
                cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
                cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
                dBetaData, eps, meanCacheData, varCacheData));
@ -16,21 +16,28 @@
|
|||
|
||||
package org.deeplearning4j.nn.layers.mkldnn;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
|
||||
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
|
||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -57,27 +64,53 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
|
|||
|
||||
@Override
|
||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
|
||||
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
|
||||
//2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166
|
||||
//Also no MKL-DNN implemented for backprop anyway
|
||||
INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
|
||||
if(input.dataType() != DataType.FLOAT)
|
||||
return null; //MKL-DNN only supports float
|
||||
/*
|
||||
INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon};
|
||||
//TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335
|
||||
List<INDArray> args = new ArrayList<>();
|
||||
args.add(input);
|
||||
args.add(meanCache);
|
||||
args.add(varCache);
|
||||
args.add(epsilon);
|
||||
if(gamma != null)
|
||||
args.add(gamma.reshape(gamma.length()));
|
||||
if(beta != null)
|
||||
args.add(beta.reshape(beta.length()));
|
||||
|
||||
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape());
|
||||
|
||||
INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, }
|
||||
|
||||
BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder()
|
||||
.applyBeta(gamma != null)
|
||||
.applyGamma(gamma != null)
|
||||
.axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases
|
||||
.epsilon(eps)
|
||||
.inputArrays(in)
|
||||
.outputArrays(new INDArray[]{out})
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
|
||||
.addInputs(args.toArray(new INDArray[0]))
|
||||
.addIntegerArguments(
|
||||
gamma == null ? 0 : 1, //Apply scale
|
||||
beta == null ? 0 : 1, //Apply beta
|
||||
1) //Axis (NCHW)
|
||||
.addFloatingPointArguments(eps)
|
||||
.build();
|
||||
Nd4j.exec(bn);
|
||||
*/
|
||||
|
||||
INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
||||
INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
|
||||
INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
|
||||
|
||||
op.setOutputArgument(0, epsAtInput);
|
||||
op.setOutputArgument(1, dLdm);
|
||||
op.setOutputArgument(2, dLdv);
|
||||
if(dGammaView != null) {
|
||||
//Both are always null/not null simultaneously
|
||||
op.setOutputArgument(3, dGammaView.reshape(dGammaView.length()));
|
||||
op.setOutputArgument(4, dBetaView.reshape(dBetaView.length()));
|
||||
}
|
||||
|
||||
|
||||
Nd4j.exec(op);
|
||||
|
||||
Gradient g = new DefaultGradient();
|
||||
g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
|
||||
g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
|
||||
|
||||
return new Pair<>(g, epsAtInput);
|
||||
*/
|
||||
return null;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
package org.deeplearning4j.nn.layers.mkldnn;
|
||||
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
|
||||
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class MKLDNNLSTMHelper implements LSTMHelper {
|
||||
@Override
|
||||
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
|
||||
//TODO check other activation functions for MKLDNN
|
||||
return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
|
||||
INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT,
|
||||
int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey,
|
||||
String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews,
|
||||
INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
|
||||
//Not yet implemented/supported
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
|
||||
INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training,
|
||||
INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards,
|
||||
String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
|
||||
|
||||
/*
|
||||
DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward)
|
||||
Inputs:
|
||||
x = [bS, nIn, sL]
|
||||
Wx = [nIn, 4*nOut]
|
||||
Wr = [nOut, 4*nOut]
|
||||
Wp = [3*nOut] Optional peephole weights
|
||||
b = [4*nOut]
|
||||
seqLen = [bS]
|
||||
initialOut = [bs, nOut]
|
||||
initialCell = [bs, nOut]
|
||||
|
||||
Outputs:
|
||||
out = [bS, nOut, sL]
|
||||
outLast = [bs, nOut]
|
||||
cellLast = [bs,nOut]
|
||||
|
||||
Gates order: input, forget, input modulation, output
|
||||
|
||||
|
||||
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
|
||||
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
|
||||
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
|
||||
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
|
||||
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
|
||||
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
|
||||
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
||||
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
||||
*/
|
||||
|
||||
INDArray b1d = biases.reshape(biases.length());
|
||||
INDArray seqLen = null;
|
||||
if(maskArray != null){
|
||||
seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen])
|
||||
}
|
||||
|
||||
List<INDArray> args = new ArrayList<>();
|
||||
args.add(input);
|
||||
args.add(inputWeights);
|
||||
args.add(recurrentWeights);
|
||||
if(hasPeepholeConnections){
|
||||
throw new IllegalStateException("Not yet implemented");
|
||||
}
|
||||
args.add(b1d);
|
||||
if(seqLen != null)
|
||||
args.add(seqLen);
|
||||
if(prevOutputActivations != null)
|
||||
args.add(prevOutputActivations);
|
||||
if(prevMemCellState != null)
|
||||
args.add(prevMemCellState);
|
||||
|
||||
IActivation a = ((LSTM)conf.getLayer()).getActivationFn();
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
|
||||
.addInputs(args.toArray(new INDArray[0]))
|
||||
.addBooleanArguments(
|
||||
true, //hasBiases
|
||||
seqLen != null, //hasSeqLen
|
||||
prevOutputActivations != null, //hasInitH
|
||||
prevMemCellState != null, //hasInitC
|
||||
hasPeepholeConnections, //hasPh
|
||||
true, //retFullSeq
|
||||
true, //retLastH
|
||||
true //retLastC
|
||||
)
|
||||
.addIntegerArguments(
|
||||
2, //data format: 2 = [bS, nIn, sL]
|
||||
0, //direction: 0 = forward
|
||||
activationToArg(gateActivationFn), //Gate activation
|
||||
activationToArg(a), //Cell state activation
|
||||
activationToArg(a) //Output activation (same as cell in DL4J)
|
||||
)
|
||||
.build();
|
||||
|
||||
List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
|
||||
|
||||
for(LongShapeDescriptor lsd : outShapes){
|
||||
INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
|
||||
op.addOutputArgument(arr);
|
||||
}
|
||||
|
||||
FwdPassReturn f = new FwdPassReturn();
|
||||
f.fwdPassOutput = op.getOutputArgument(0);
|
||||
f.lastAct = op.getOutputArgument(1);
|
||||
f.lastMemCell = op.getOutputArgument(2);
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Long> helperMemoryUse() {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
private int activationToArg(IActivation a){
|
||||
//0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||
if(a instanceof ActivationTanH)
|
||||
return 0;
|
||||
if(a instanceof ActivationReLU)
|
||||
return 1;
|
||||
if(a instanceof ActivationSigmoid)
|
||||
return 2;
|
||||
if(a instanceof ActivationIdentity)
|
||||
return 3;
|
||||
if(a instanceof ActivationLReLU)
|
||||
return 4;
|
||||
if(a instanceof ActivationThresholdedReLU)
|
||||
return 5;
|
||||
if(a instanceof ActivationHardSigmoid)
|
||||
return 7;
|
||||
if(a instanceof ActivationELU)
|
||||
return 8;
|
||||
if(a instanceof ActivationSoftSign)
|
||||
return 9;
|
||||
if(a instanceof ActivationSoftPlus)
|
||||
return 10;
|
||||
throw new IllegalStateException("Unknown or not supported activation function: " + a);
|
||||
}
|
||||
}
|
|
@@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
        INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR);       //One of log10std will be null depending on config
        INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
        INDArray gamma = null;
        INDArray beta = null;
        INDArray dGammaView;
        INDArray dBetaView;
        INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);

@@ -129,6 +130,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
            dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
        } else {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            beta = getParam(BatchNormalizationParamInitializer.BETA);
            dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
            dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
        }

@@ -154,12 +156,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay

        Pair<Gradient,INDArray> ret = null;
        try {
            ret = helper.backpropGradient(in, eps, shape, gamma, dGammaView, dBetaView,
            ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView,
                    layerConf.getEps(), workspaceMgr);
        } catch (ND4JOpProfilerException e){
            throw e;    //NaN panic etc for debugging
        } catch (Throwable t){
            if(t.getMessage().contains("Failed to allocate")){
            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                //This is a memory exception - don't fallback to built-in implementation
                throw t;
            }

@@ -451,7 +453,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
        } catch (ND4JOpProfilerException e){
            throw e;    //NaN panic etc for debugging
        } catch (Throwable t) {
            if(t.getMessage().contains("Failed to allocate")){
            if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                //This is a memory exception - don't fallback to built-in implementation
                throw t;
            }
@@ -31,8 +31,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
public interface BatchNormalizationHelper extends LayerHelper {
    boolean checkSupported(double eps, boolean fixedGammaBeta);

    Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
                                              INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
    Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                                              INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);

    INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
                       INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);
@@ -144,7 +144,7 @@ public class LocalResponseNormalization
            } catch (ND4JOpProfilerException e){
                throw e;    //NaN panic etc for debugging
            } catch (Throwable t){
                if(t.getMessage().contains("Failed to allocate")){
                if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                    //This is a memory exception - don't fallback to built-in implementation
                    throw t;
                }

@@ -211,7 +211,7 @@ public class LocalResponseNormalization
            } catch (ND4JOpProfilerException e){
                throw e;    //NaN panic etc for debugging
            } catch (Throwable t){
                if(t.getMessage().contains("Failed to allocate")){
                if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
                    //This is a memory exception - don't fallback to built-in implementation
                    throw t;
                }
@@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;

@@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
                }
            }
        }
        /*
        //Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
        else if ("CPU".equalsIgnoreCase(backend) && BaseMKLDNNHelper.mklDnnEnabled()){
            helper = new MKLDNNLSTMHelper();
            log.debug("MKLDNNLSTMHelper successfully initialized");
            if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
                helper = null;
            }
        }
        */
    }

    @Override
@@ -215,7 +215,7 @@ public class UNet extends ZooModel {
                        .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                        .activation(Activation.RELU).build(), "conv9-2")

                .addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1)
                .addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(1)
                        .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
                        .activation(Activation.IDENTITY).build(), "conv9-3")
                .addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
@@ -122,7 +122,7 @@ public interface Listener {
    /**
     * Called when any activation becomes available.
     * <p>
     * The activation will most likely be freed later, use detach() if you need to save it.<br>
     * The activation will most likely be freed later, use dup() if you need to save it.<br>
     * <br>
     * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
     * It is guaranteed to be called for variables from requiredVariables().<br>
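The javadoc fix above (detach() changed to dup()) means listeners should copy any activation they want to keep beyond the callback. A minimal listener sketch, assuming the BaseListener adapter class from the same org.nd4j.autodiff.listeners package and the activationAvailable signature shown in the TestListener added later in this diff:

import java.util.HashMap;
import java.util.Map;

import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;   //Assumed no-op adapter; implement Listener directly if unavailable
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class ActivationCapturingListener extends BaseListener {
    private final Map<String, INDArray> saved = new HashMap<>();

    @Override
    public boolean isActive(Operation operation) {
        return true;    //Listen during all operations
    }

    @Override
    public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op,
                                    String varName, INDArray activation) {
        //The array may be released/reused once execution moves on, so keep a copy
        saved.put(varName, activation.dup());
    }
}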
@@ -29,6 +29,7 @@ import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;

@@ -1642,7 +1643,13 @@ public class SameDiff extends SDBaseOps {

        Set<String> requiredVars = new HashSet<>();
        for (Listener l : activeListeners) {
            requiredVars.addAll(l.requiredVariables(this).trainingVariables());
            ListenerVariables lv = l.requiredVariables(this);
            if(lv != null) {
                Set<String> s = lv.trainingVariables();
                if(s != null) {
                    requiredVars.addAll(s);
                }
            }
        }

        List<Listener> listenersWitHistory = new ArrayList<>(listeners);

@@ -1661,6 +1668,10 @@ public class SameDiff extends SDBaseOps {
        TrainingSession ts = new TrainingSession(gradInstance);
        gradInstance.setTrainingConfig(trainingConfig);     //In case any listeners want to use it

        for(Listener l : activeListeners){
            l.operationStart(gradInstance, Operation.TRAINING);
        }

        Set<String> paramsToTrain = new LinkedHashSet<>();
        for(Variable v : variables.values()){
            if(v.getVariable().getVariableType() == VariableType.VARIABLE){

@@ -1844,9 +1855,12 @@ public class SameDiff extends SDBaseOps {
     */
    private void validateListenerActivations(List<Listener> listeners, Operation op) {
        for (Listener l : listeners) {
            for (String s : l.requiredVariables(this).requiredVariables(op)) {
                if (!variables.containsKey(s)) {
                    Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
            ListenerVariables lv = l.requiredVariables(this);
            if(lv != null) {
                for (String s : lv.requiredVariables(op)) {
                    if (!variables.containsKey(s)) {
                        Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
                    }
                }
            }
        }
@ -2151,31 +2165,20 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
if (hasListeners) {
|
||||
for (Listener l : activeListeners) {
|
||||
requiredVars.addAll(l.requiredVariables(this).evaluationVariables());
|
||||
ListenerVariables v = l.requiredVariables(this);
|
||||
if(v != null) {
|
||||
requiredVars.addAll(v.evaluationVariables());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
String[] requiredVarsArr = requiredVars.toArray(new String[0]);
|
||||
|
||||
while (iterator.hasNext()) {
|
||||
long dataStart = hasListeners ? System.currentTimeMillis() : 0;
|
||||
MultiDataSet ds = iterator.next();
|
||||
long dataEnd = hasListeners ? System.currentTimeMillis() : 0;
|
||||
Map<String, INDArray> placeholderMap = toPlaceholderMap(ds);
|
||||
|
||||
Map<String, INDArray> m;
|
||||
Map<String, INDArray> outs = null;
|
||||
if (hasListeners) {
|
||||
|
||||
for (Listener l : activeListeners) {
|
||||
l.iterationStart(this, at, ds, (dataEnd - dataStart));
|
||||
}
|
||||
|
||||
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
||||
} else {
|
||||
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
||||
}
|
||||
|
||||
Map<String, INDArray> m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
||||
|
||||
for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
|
||||
INDArray prediction = m.get(e.getKey());
|
||||
|
@ -2188,15 +2191,6 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
}
|
||||
|
||||
if (hasListeners) {
|
||||
for (Listener l : activeListeners) {
|
||||
Map<String, INDArray> outVars = Maps.newHashMap(
|
||||
Maps.filterKeys(outs,
|
||||
Predicates.in(l.requiredVariables(this).evaluationVariables())));
|
||||
l.iterationDone(this, at, ds, null);
|
||||
}
|
||||
}
|
||||
|
||||
at.setIteration(at.iteration() + 1);
|
||||
}
|
||||
|
||||
|
@@ -2518,7 +2512,7 @@ public class SameDiff extends SDBaseOps {
     * Special case of {@link #batchOutput()}.
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
        return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec();
        return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output();
    }

    /**

@@ -2529,7 +2523,7 @@ public class SameDiff extends SDBaseOps {
     * Special case of {@link #batchOutput()}.
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, String... outputs) {
        return batchOutput().output(outputs).inputs(placeholders).exec();
        return batchOutput().output(outputs).inputs(placeholders).output();
    }

@@ -2542,31 +2536,36 @@ public class SameDiff extends SDBaseOps {
     * @param listeners Additional listeners to use during this operation.
     * @param outputs The variables to output and return.
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
        return batchOutputHelper(placeholders, listeners, outputs);
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<Listener> listeners, String... outputs) {
        return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs);
    }

    protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
    protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs) {
        List<Listener> activeListeners = new ArrayList<>();

        if(operation == null)
            operation = Operation.INFERENCE;

        for (Listener l : this.listeners)
            if (l.isActive(Operation.INFERENCE))
            if (l.isActive(operation))
                activeListeners.add(l);

        for (Listener l : listeners)
            if (l.isActive(Operation.INFERENCE))
                activeListeners.add(l);

        for (Listener l : activeListeners) {
            l.operationStart(this, Operation.INFERENCE);
        if(listeners != null) {
            for (Listener l : listeners)
                if (l.isActive(operation))
                    activeListeners.add(l);
        }

        validateListenerActivations(activeListeners, Operation.INFERENCE);
        for (Listener l : activeListeners) {
            l.operationStart(this, operation);
        }

        Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.<String>emptyList(), activeListeners, outputs);
        validateListenerActivations(activeListeners, operation);

        Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);

        for (Listener l : activeListeners) {
            l.operationEnd(this, Operation.INFERENCE);
            l.operationEnd(this, operation);
        }
        return ret;
    }
@@ -3992,7 +3991,6 @@ public class SameDiff extends SDBaseOps {

            sameDiffFunctionInstances.put(function, sub);
        }

    }

    /**
@ -4012,32 +4010,64 @@ public class SameDiff extends SDBaseOps {
|
|||
*/
|
||||
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
|
||||
Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
|
||||
OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables);
|
||||
return oag.getGradients();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the activations and the gradients for the specified variables, in one execution call.
|
||||
* This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but
|
||||
* is more efficient than calling both separately.
|
||||
*
|
||||
* @param placeholderVals Placeholders. May be null
|
||||
* @param outputVars Names of the variables that you want the activations/outputs for. May be null
|
||||
* @param gradientVars Names of the variables that you want the gradient arrays for. May be null
|
||||
* @return Activations and gradients, keyed by variable name
|
||||
*/
|
||||
public OutAndGrad calculateGradientsAndOutputs(Map<String,INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars){
|
||||
Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()),
|
||||
"No variables were specified for either output or gradients");
|
||||
if (getFunction(GRAD_FN_KEY) == null) {
|
||||
createGradFunction();
|
||||
}
|
||||
|
||||
List<String> gradVarNames = new ArrayList<>(variables.size());
|
||||
for (String s : variables) {
|
||||
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
|
||||
SDVariable v = getVariable(s).getGradient();
|
||||
if (v != null) {
|
||||
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
|
||||
gradVarNames.add(v.name());
|
||||
List<String> varNames = new ArrayList<>();
|
||||
if(outputVars != null){
|
||||
varNames.addAll(outputVars);
|
||||
}
|
||||
if(gradientVars != null) {
|
||||
for (String s : gradientVars) {
|
||||
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
|
||||
SDVariable v = getVariable(s).getGradient();
|
||||
if (v != null) {
|
||||
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
|
||||
varNames.add(v.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Key is gradient variable name
|
||||
Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);
|
||||
SameDiff gradFn = getFunction(GRAD_FN_KEY);
|
||||
gradFn.setListeners(listeners);
|
||||
Map<String, INDArray> grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0]));
|
||||
|
||||
Map<String, INDArray> out = new HashMap<>();
|
||||
for (String s : variables) {
|
||||
if (getVariable(s).getGradient() != null) {
|
||||
String gradVar = getVariable(s).getGradient().name();
|
||||
out.put(s, grads.get(gradVar));
|
||||
Map<String, INDArray> outOutputs = outputVars == null ? null : new HashMap<String,INDArray>();
|
||||
Map<String, INDArray> outGrads = gradientVars == null ? null : new HashMap<String,INDArray>();
|
||||
if(outputVars != null){
|
||||
for(String s : outputVars){
|
||||
outOutputs.put(s, grads.get(s));
|
||||
}
|
||||
}
|
||||
if(gradientVars != null) {
|
||||
for (String s : gradientVars) {
|
||||
if (getVariable(s).getGradient() != null) {
|
||||
String gradVar = getVariable(s).getGradient().name();
|
||||
outGrads.put(s, grads.get(gradVar));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
return new OutAndGrad(outOutputs, outGrads);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@@ -0,0 +1,19 @@
package org.nd4j.autodiff.samediff.api;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Map;

/**
 * A simple object holding two maps - one of output arrays, another of gradient arrays
 */
@AllArgsConstructor
@Data
public class OutAndGrad {

    private final Map<String, INDArray> outputs;
    private final Map<String, INDArray> gradients;

}
@@ -797,7 +797,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            } else if (v.getVariableType() == VariableType.VARIABLE) {
                args[i] = v.getArr();
            } else if (v.isPlaceHolder()) {
                Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s);
                Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", s);
                args[i] = placeholderValues.get(s);
            } else {
                VarId vid = lookup(s, opInputs, allIterInputs, true);
@@ -20,6 +20,7 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

@@ -125,7 +126,9 @@ public class ImagePreProcessingScaler implements DataNormalization {

    @Override
    public void transformLabel(INDArray label) {
        //No op
        Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use" +
                " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", label);
        transform(label);
    }

    @Override

@@ -161,7 +164,9 @@ public class ImagePreProcessingScaler implements DataNormalization {

    @Override
    public void revertLabels(INDArray labels) {
        //No op
        Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use" +
                " cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", labels);
        revertFeatures(labels);
    }

    @Override

@@ -171,9 +176,7 @@ public class ImagePreProcessingScaler implements DataNormalization {

    @Override
    public void fitLabel(boolean fitLabels) {
        if (fitLabels) {
            log.warn("Labels fitting not currently supported for ImagePreProcessingScaler. Labels will not be modified");
        }
        //No-op
    }

    @Override
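With the change above, ImagePreProcessingScaler scales rank-4 segmentation label masks the same way it scales features (#8135), as exercised by the preprocessor test later in this diff. A minimal sketch, using random arrays as stand-ins for real image data and masks:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;

public class SegmentationScalingExample {
    public static void main(String[] args) {
        //[minibatch, channels, height, width] features and per-pixel label masks in the 0..255 range
        INDArray features = Nd4j.rand(DataType.FLOAT, 2, 3, 16, 16).muli(255);
        INDArray labels = Nd4j.rand(DataType.FLOAT, 2, 1, 16, 16).muli(255);

        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
        scaler.transform(features);      //Features scaled to [0, 1]
        scaler.transformLabel(labels);   //Rank-4 labels now scaled too (segmentation support)

        scaler.revertLabels(labels);     //Undo the label scaling if needed
    }
}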
@@ -5831,12 +5831,6 @@ public class Nd4j {
                }
            }
            case UBYTE:
                UInt8Buffer b = new UInt8Buffer(ArrayUtil.prod(shapeOf));
                val sb = bb.order(_order).asReadOnlyBuffer();
                for (int e = 0; e < prod; e++)
                    b.put(e, sb.get(e));

                return Nd4j.create(b, shapeOf);
            case BFLOAT16:
            case UINT16:
                INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf);
@@ -1000,8 +1000,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        val dataType = op.resultType();

        val ret = Nd4j.createUninitialized(dataType, retShape);
        op.setZ(ret);
        if( op.z() == null ){
            val ret = Nd4j.createUninitialized(dataType, retShape);
            op.setZ(ret);
        } else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){
            throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape)
                    + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape()));
        }

        val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;
@@ -39,6 +39,7 @@ import org.junit.Ignore;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;

@@ -3426,4 +3427,30 @@ public class SameDiffTests extends BaseNd4jTest {
        INDArray a1 = rand1.eval();
        assertEquals(a0, a1);
    }


    @Test
    public void testCalculateGradientsAndOutputs(){
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
        SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
        SDVariable z = in.mmul(w).add("z", b);
        SDVariable softmax = sd.nn.softmax("softmax", z);

        Map<String,INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
        List<String> outputs = Arrays.asList("in", "z", "softmax");
        List<String> grads = Arrays.asList("in", "w", "z");

        OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, outputs, grads);
        Map<String,INDArray> outs = oag.getOutputs();
        Map<String,INDArray> g = oag.getGradients();


        Map<String,INDArray> outExp = sd.output(ph, outputs);
        Map<String,INDArray> gExp = sd.calculateGradients(ph, grads);

        assertEquals(outExp, outs);
        assertEquals(gExp, g);
    }
}
@ -19,22 +19,28 @@ package org.nd4j.autodiff.samediff.listeners;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.listeners.*;
|
||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -49,6 +55,11 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void irisHistoryTest() {
|
||||
|
||||
|
@ -112,8 +123,237 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
@Test
|
||||
public void testListenerCalls(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
|
||||
SDVariable z = in.mmul(w).add(b);
|
||||
SDVariable softmax = sd.nn.softmax("softmax", z);
|
||||
SDVariable loss = sd.loss.logLoss("loss" ,label, softmax);
|
||||
|
||||
TestListener tl = new TestListener(Operation.INFERENCE);
|
||||
sd.setListeners(tl);
|
||||
|
||||
//Check listener called during inference
|
||||
Map<String,INDArray> phMap = Collections.singletonMap("in", Nd4j.rand(1, 4));
|
||||
|
||||
for( int i=1; i<=5; i++ ) {
|
||||
INDArray out = sd.outputSingle(phMap, "softmax");
|
||||
|
||||
assertEquals(0, tl.epochStartCount);
|
||||
assertEquals(0, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(0, tl.iterationStartCount);
|
||||
assertEquals(0, tl.iterationDoneCount);
|
||||
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationStartCount);
|
||||
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationEndCount);
|
||||
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
|
||||
assertEquals(3*i, tl.opExecutionCount);
|
||||
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax outputs
|
||||
assertEquals(0, tl.preUpdateCount); //Inference -> no updating
|
||||
}
|
||||
|
||||
//Check listener NOT called during inference when set to Operation.TRAINING
|
||||
tl = new TestListener(Operation.TRAINING);
|
||||
sd.setListeners(tl);
|
||||
sd.outputSingle(phMap, "softmax");
|
||||
|
||||
assertEquals(0, tl.epochStartCount);
|
||||
assertEquals(0, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(0, tl.iterationStartCount);
|
||||
assertEquals(0, tl.iterationDoneCount);
|
||||
assertEquals(Collections.emptyMap(), tl.operationStartCount);
|
||||
assertEquals(Collections.emptyMap(), tl.operationEndCount);
|
||||
assertEquals(0, tl.preOpExecutionCount);
|
||||
assertEquals(0, tl.opExecutionCount);
|
||||
assertEquals(0, tl.activationAvailableCount);
|
||||
assertEquals(0, tl.preUpdateCount);
|
||||
|
||||
//Check listener called during gradient calculation
|
||||
tl = new TestListener(Operation.TRAINING);
|
||||
sd.setListeners(tl);
|
||||
phMap = new HashMap<>();
|
||||
phMap.put("in", Nd4j.rand( DataType.FLOAT, 1, 4));
|
||||
phMap.put("label", Nd4j.createFromArray(0f, 1f, 0f).reshape(1, 3));
|
||||
|
||||
for( int i=1; i<=3; i++ ) {
|
||||
sd.calculateGradients(phMap, "in", "w", "b");
|
||||
assertEquals(0, tl.epochStartCount);
|
||||
assertEquals(0, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(0, tl.iterationStartCount);
|
||||
assertEquals(0, tl.iterationDoneCount);
|
||||
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
|
||||
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
|
||||
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
|
||||
assertEquals(7*i, tl.opExecutionCount);
|
||||
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
|
||||
assertEquals(0, tl.preUpdateCount);
|
||||
}
|
||||
|
||||
|
||||
//Check listener NOT called during gradient calculation - when listener is still set to INFERENCE mode
|
||||
tl = new TestListener(Operation.INFERENCE);
|
||||
sd.setListeners(tl);
|
||||
for( int i=1; i<=3; i++ ) {
|
||||
sd.calculateGradients(phMap, "in", "w", "b");
|
||||
assertEquals(0, tl.epochStartCount);
|
||||
assertEquals(0, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(0, tl.iterationStartCount);
|
||||
assertEquals(0, tl.iterationDoneCount);
|
||||
assertEquals(Collections.emptyMap(), tl.operationStartCount);
|
||||
assertEquals(Collections.emptyMap(), tl.operationEndCount);
|
||||
assertEquals(0, tl.preOpExecutionCount);
|
||||
assertEquals(0, tl.opExecutionCount);
|
||||
assertEquals(0, tl.activationAvailableCount);
|
||||
assertEquals(0, tl.preUpdateCount);
|
||||
}
|
||||
|
||||
//Check fit:
|
||||
tl = new TestListener(Operation.TRAINING);
|
||||
sd.setListeners(tl);
|
||||
sd.setTrainingConfig(TrainingConfig.builder()
|
||||
.dataSetFeatureMapping("in")
|
||||
.dataSetLabelMapping("label")
|
||||
.updater(new Adam(1e-3))
|
||||
.build());
|
||||
|
||||
SingletonDataSetIterator dsi = new SingletonDataSetIterator(new DataSet(phMap.get("in"), phMap.get("label")));
|
||||
for( int i=1; i<=3; i++ ) {
|
||||
sd.fit(dsi, 1);
|
||||
assertEquals(i, tl.epochStartCount);
|
||||
assertEquals(i, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(i, tl.iterationStartCount);
|
||||
assertEquals(i, tl.iterationDoneCount);
|
||||
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
|
||||
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
|
||||
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
|
||||
assertEquals(7*i, tl.opExecutionCount);
|
||||
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
|
||||
assertEquals(2*i, tl.preUpdateCount); //w, b
|
||||
}
|
||||
|
||||
|
||||
//Check evaluation:
|
||||
tl = new TestListener(Operation.EVALUATION);
|
||||
sd.setListeners(tl);
|
||||
|
||||
for( int i=1; i<=3; i++ ) {
|
||||
sd.evaluate(dsi, "softmax", new Evaluation());
|
||||
assertEquals(0, tl.epochStartCount);
|
||||
assertEquals(0, tl.epochEndCount);
|
||||
assertEquals(0, tl.validationDoneCount);
|
||||
assertEquals(0, tl.iterationStartCount);
|
||||
assertEquals(0, tl.iterationDoneCount);
|
||||
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationStartCount);
|
||||
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationEndCount);
|
||||
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
|
||||
assertEquals(3*i, tl.opExecutionCount);
|
||||
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax
|
||||
assertEquals(0, tl.preUpdateCount); //w, b
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestListener implements Listener {
|
||||
|
||||
public TestListener(Operation operation){
|
||||
this.operation = operation;
|
||||
}
|
||||
|
||||
private final Operation operation;
|
||||
|
||||
private int epochStartCount = 0;
|
||||
private int epochEndCount = 0;
|
||||
private int validationDoneCount = 0;
|
||||
private int iterationStartCount = 0;
|
||||
private int iterationDoneCount = 0;
|
||||
private Map<Operation,Integer> operationStartCount = new HashMap<>();
|
||||
private Map<Operation,Integer> operationEndCount = new HashMap<>();
|
||||
private int preOpExecutionCount = 0;
|
||||
private int opExecutionCount = 0;
|
||||
private int activationAvailableCount = 0;
|
||||
private int preUpdateCount = 0;
|
||||
|
||||
|
||||
@Override
|
||||
public ListenerVariables requiredVariables(SameDiff sd) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return this.operation == null || this.operation == operation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochStart(SameDiff sd, At at) {
|
||||
epochStartCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
epochEndCount++;
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
|
||||
validationDoneCount++;
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) {
|
||||
iterationStartCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
|
||||
iterationDoneCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationStart(SameDiff sd, Operation op) {
|
||||
if(!operationStartCount.containsKey(op)) {
|
||||
operationStartCount.put(op, 1);
|
||||
} else {
|
||||
operationStartCount.put(op, operationStartCount.get(op) + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationEnd(SameDiff sd, Operation op) {
|
||||
if(!operationEndCount.containsKey(op)) {
|
||||
operationEndCount.put(op, 1);
|
||||
} else {
|
||||
operationEndCount.put(op, operationEndCount.get(op) + 1);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
preOpExecutionCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
opExecutionCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) {
|
||||
activationAvailableCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
|
||||
preUpdateCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,11 +20,13 @@ import org.junit.Test;
|
|||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
@ -42,32 +44,32 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
public void simpleImageTest() {
|
||||
INDArray rChannels = Nd4j.zeros(10, 10).addi(128);
|
||||
INDArray gChannels = Nd4j.zeros(10, 10).addi(64);
|
||||
INDArray bChannels = Nd4j.zeros(10, 10).addi(255);
|
||||
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(3, 10, 10);
|
||||
INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128);
|
||||
INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64);
|
||||
INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255);
|
||||
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(1, 3, 10, 10);
|
||||
INDArray orig = image.dup();
|
||||
|
||||
//System.out.println(Arrays.toString(image.shape()));
|
||||
DataSet ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
|
||||
DataSet ds = new DataSet(image, Nd4j.ones(1, 1));
|
||||
ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler();
|
||||
//So this should scale to 0.5,0.25 and 1;
|
||||
INDArray expected = image.mul(0);
|
||||
expected.slice(0, 0).addi(0.5);
|
||||
expected.slice(1, 0).addi(0.25);
|
||||
expected.slice(2, 0).addi(1.0);
|
||||
expected.slice(0, 1).addi(0.5);
|
||||
expected.slice(1, 1).addi(0.25);
|
||||
expected.slice(2, 1).addi(1.0);
|
||||
myScaler.transform(ds);
|
||||
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
||||
|
||||
//Now giving it 16 bits instead of the default
|
||||
//System.out.println(Arrays.toString(image.shape()));
|
||||
ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
|
||||
ds = new DataSet(image, Nd4j.ones(1, 1));
|
||||
myScaler = new ImagePreProcessingScaler(0, 1, 16);
|
||||
//So this should scale to 0.5,0.25 and 1;
|
||||
expected = image.mul(0);
|
||||
expected.slice(0, 0).addi(0.5 / 256);
|
||||
expected.slice(1, 0).addi(0.25 / 256);
|
||||
expected.slice(2, 0).addi(1.0 / 256);
|
||||
expected.slice(0, 1).addi(0.5 / 256);
|
||||
expected.slice(1, 1).addi(0.25 / 256);
|
||||
expected.slice(2, 1).addi(1.0 / 256);
|
||||
myScaler.transform(ds);
|
||||
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
||||
|
||||
|
@ -88,6 +90,16 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
|
|||
myScaler.transform(before);
|
||||
myScaler.revertFeatures(before);
|
||||
assertEquals(orig, before);
|
||||
|
||||
|
||||
//Test labels (segmentation case)
|
||||
before = orig.dup();
|
||||
myScaler = new ImagePreProcessingScaler(0, 1);
|
||||
myScaler.transformLabel(before);
|
||||
expected = orig.div(255);
|
||||
assertEquals(expected, before);
|
||||
myScaler.revertLabels(before);
|
||||
assertEquals(orig, before);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|