Various SameDiff fixes (#21)
* MKLDNN LSTM forward implementation (disabled pending #8331) Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8318 add SameDiff.calculateGradientsAndOutputs Signed-off-by: AlexDBlack <blacka101@gmail.com> * Disable mkldnn backprop for now - pending fix, issue #8335 Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8337 Fix CudaExecutioner unnecessary result array allocation/replacement Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small FlatBuffers serde fix, UInt8 Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8135 ImagePreProcessingScaler - add segmentation support Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8319 Ensure listeners are called when they are supposed to be called Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8214 UNet (non-pretrained) last conv layer kernal size fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
b816845797
commit
d82877b18b
|
@ -0,0 +1,107 @@
|
||||||
|
package org.deeplearning4j;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
|
||||||
|
import static junit.framework.TestCase.*;
|
||||||
|
|
||||||
|
public class TestBatchNormBp {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test(){
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
// INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 4, 4);
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||||
|
INDArray mean = in.mean(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray var = in.var(0, 2, 3); //Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||||
|
// INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
|
||||||
|
// INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
|
||||||
|
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
double e = 1e-5;
|
||||||
|
|
||||||
|
INDArray dLdIn = in.ulike();
|
||||||
|
INDArray dLdm = mean.ulike();
|
||||||
|
INDArray dLdv = var.ulike();
|
||||||
|
INDArray dLdg = gamma.ulike();
|
||||||
|
INDArray dLdb = beta.ulike();
|
||||||
|
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
|
||||||
|
.addInputs(in, mean, var, eps, gamma, beta)
|
||||||
|
.addIntegerArguments(
|
||||||
|
1, //Apply scale
|
||||||
|
1, //Apply beta
|
||||||
|
1) //Axis (NCHW)
|
||||||
|
.addFloatingPointArguments(e)
|
||||||
|
.addOutputs(dLdIn, dLdm, dLdv, dLdg, dLdb)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.exec(op);
|
||||||
|
System.out.println(dLdIn);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void compareImpls() throws Exception {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||||
|
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
|
||||||
|
INDArray var = in.var(0, 2, 3).reshape(1,3);
|
||||||
|
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||||
|
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||||
|
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||||
|
double e = 1e-3;
|
||||||
|
|
||||||
|
INDArray dLdIn = in.ulike();
|
||||||
|
INDArray dLdm = mean.ulike();
|
||||||
|
INDArray dLdv = var.ulike();
|
||||||
|
INDArray dLdg = gamma.ulike();
|
||||||
|
INDArray dLdb = beta.ulike();
|
||||||
|
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.inferenceWorkspaceMode(WorkspaceMode.NONE)
|
||||||
|
.trainingWorkspaceMode(WorkspaceMode.NONE)
|
||||||
|
.list()
|
||||||
|
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
|
||||||
|
assertNotNull(bn.getHelper());
|
||||||
|
Field f = bn.getClass().getDeclaredField("helper");
|
||||||
|
f.setAccessible(true);
|
||||||
|
f.set(bn, null);
|
||||||
|
assertNull(bn.getHelper());
|
||||||
|
|
||||||
|
|
||||||
|
MKLDNNBatchNormHelper h = new MKLDNNBatchNormHelper(DataType.FLOAT);
|
||||||
|
|
||||||
|
net.output(in, true);
|
||||||
|
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
|
||||||
|
h.preOutput(in, true, new int[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new int[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
|
||||||
|
INDArray dldin_dl4j = p.getSecond();
|
||||||
|
|
||||||
|
System.out.println("dl4j == mkldnn: " + p.getSecond().equals(pmkl.getSecond()));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -23,10 +23,13 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.WorkspaceMode;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -36,10 +39,13 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static junit.framework.TestCase.*;
|
||||||
import static org.junit.Assume.assumeTrue;
|
import static org.junit.Assume.assumeTrue;
|
||||||
|
|
||||||
public class ValidateMKLDNN extends BaseDL4JTest {
|
public class ValidateMKLDNN extends BaseDL4JTest {
|
||||||
|
@ -148,7 +154,7 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build())
|
.build())
|
||||||
.layer(new BatchNormalization.Builder().cudnnAllowFallback(false).build())
|
.layer(new BatchNormalization.Builder().helperAllowFallback(false)/*.eps(0)*/.build())
|
||||||
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
|
.layer(new ConvolutionLayer.Builder().activation(Activation.TANH)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -256,4 +262,54 @@ public class ValidateMKLDNN extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void compareBatchNormBackward() throws Exception {
|
||||||
|
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 15, 15);
|
||||||
|
INDArray mean = in.mean(0, 2, 3).reshape(1,3);
|
||||||
|
INDArray var = in.var(0, 2, 3).reshape(1,3);
|
||||||
|
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||||
|
INDArray gamma = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||||
|
INDArray beta = Nd4j.rand(DataType.FLOAT, 1,3);
|
||||||
|
double e = 1e-3;
|
||||||
|
|
||||||
|
INDArray dLdIn = in.ulike();
|
||||||
|
INDArray dLdm = mean.ulike();
|
||||||
|
INDArray dLdv = var.ulike();
|
||||||
|
INDArray dLdg = gamma.ulike();
|
||||||
|
INDArray dLdb = beta.ulike();
|
||||||
|
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.inferenceWorkspaceMode(WorkspaceMode.NONE)
|
||||||
|
.trainingWorkspaceMode(WorkspaceMode.NONE)
|
||||||
|
.list()
|
||||||
|
.layer(new BatchNormalization.Builder().nIn(3).nOut(3).build())
|
||||||
|
.build();
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
org.deeplearning4j.nn.layers.normalization.BatchNormalization bn = (org.deeplearning4j.nn.layers.normalization.BatchNormalization) net.getLayer(0);
|
||||||
|
assertNotNull(bn.getHelper());
|
||||||
|
System.out.println(bn.getHelper());
|
||||||
|
|
||||||
|
net.output(in, true);
|
||||||
|
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
Pair<Gradient,INDArray> pcudnn = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
|
||||||
|
Field f = bn.getClass().getDeclaredField("helper");
|
||||||
|
f.setAccessible(true);
|
||||||
|
f.set(bn, null);
|
||||||
|
assertNull(bn.getHelper());
|
||||||
|
|
||||||
|
net.output(in, true);
|
||||||
|
bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
|
||||||
|
|
||||||
|
INDArray dldin_dl4j = p.getSecond();
|
||||||
|
INDArray dldin_helper = pcudnn.getSecond();
|
||||||
|
|
||||||
|
assertTrue(dldin_dl4j.equalsWithEps(dldin_helper, 1e-5));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma,
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
|
||||||
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
|
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
|
||||||
this.eps = eps;
|
this.eps = eps;
|
||||||
val miniBatch = (int) input.size(0);
|
val miniBatch = (int) input.size(0);
|
||||||
|
@ -189,7 +189,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
|
||||||
Pointer varCacheData = allocator.getPointer(varCache, context);
|
Pointer varCacheData = allocator.getPointer(varCache, context);
|
||||||
|
|
||||||
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
|
||||||
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, beta, alpha, alpha,
|
checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha,
|
||||||
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
|
cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
|
||||||
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
|
cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData,
|
||||||
dBetaData, eps, meanCacheData, varCacheData));
|
dBetaData, eps, meanCacheData, varCacheData));
|
||||||
|
|
|
@ -16,21 +16,28 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.mkldnn;
|
package org.deeplearning4j.nn.layers.mkldnn;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
|
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
|
||||||
|
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.OpContext;
|
import org.nd4j.linalg.api.ops.OpContext;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
||||||
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -57,27 +64,53 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
|
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
|
||||||
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
|
INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) {
|
||||||
//2019-02-14: Backprop disabled pending fixes. https://github.com/deeplearning4j/deeplearning4j/issues/7166
|
if(input.dataType() != DataType.FLOAT)
|
||||||
//Also no MKL-DNN implemented for backprop anyway
|
return null; //MKL-DNN only supports float
|
||||||
/*
|
/*
|
||||||
INDArray[] in = gamma == null ? new INDArray[]{input, mean, var, epsilon} : new INDArray[]{input, mean, var, gamma, beta, epsilon};
|
//TODO FIXME - AB 2019/11/01 - https://github.com/eclipse/deeplearning4j/issues/8335
|
||||||
|
List<INDArray> args = new ArrayList<>();
|
||||||
|
args.add(input);
|
||||||
|
args.add(meanCache);
|
||||||
|
args.add(varCache);
|
||||||
|
args.add(epsilon);
|
||||||
|
if(gamma != null)
|
||||||
|
args.add(gamma.reshape(gamma.length()));
|
||||||
|
if(beta != null)
|
||||||
|
args.add(beta.reshape(beta.length()));
|
||||||
|
|
||||||
INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.shape());
|
|
||||||
|
|
||||||
INDArray[] out = gamma == null ? new INDArray[]{gradAtInput, }
|
DynamicCustomOp op = DynamicCustomOp.builder("batchnorm_bp")
|
||||||
|
.addInputs(args.toArray(new INDArray[0]))
|
||||||
BatchNormDerivative bn = BatchNormDerivative.derivativeBuilder()
|
.addIntegerArguments(
|
||||||
.applyBeta(gamma != null)
|
gamma == null ? 0 : 1, //Apply scale
|
||||||
.applyGamma(gamma != null)
|
beta == null ? 0 : 1, //Apply beta
|
||||||
.axis(new int[]{1}) //4d: is channels: NCHW; 2d: is nIn - axis 1 in both cases
|
1) //Axis (NCHW)
|
||||||
.epsilon(eps)
|
.addFloatingPointArguments(eps)
|
||||||
.inputArrays(in)
|
|
||||||
.outputArrays(new INDArray[]{out})
|
|
||||||
.build();
|
.build();
|
||||||
Nd4j.exec(bn);
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
INDArray epsAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
|
||||||
|
INDArray dLdm = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
|
||||||
|
INDArray dLdv = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, meanCache.dataType(), meanCache.shape());
|
||||||
|
|
||||||
|
op.setOutputArgument(0, epsAtInput);
|
||||||
|
op.setOutputArgument(1, dLdm);
|
||||||
|
op.setOutputArgument(2, dLdv);
|
||||||
|
if(dGammaView != null) {
|
||||||
|
//Both are always null/not null simultaneously
|
||||||
|
op.setOutputArgument(3, dGammaView.reshape(dGammaView.length()));
|
||||||
|
op.setOutputArgument(4, dBetaView.reshape(dBetaView.length()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Nd4j.exec(op);
|
||||||
|
|
||||||
|
Gradient g = new DefaultGradient();
|
||||||
|
g.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
|
||||||
|
g.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
|
||||||
|
|
||||||
|
return new Pair<>(g, epsAtInput);
|
||||||
|
*/
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
package org.deeplearning4j.nn.layers.mkldnn;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
|
||||||
|
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
|
||||||
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
|
import org.nd4j.linalg.activations.impl.*;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class MKLDNNLSTMHelper implements LSTMHelper {
|
||||||
|
@Override
|
||||||
|
public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
|
||||||
|
//TODO check other activation functions for MKLDNN
|
||||||
|
return gateActivationFn instanceof ActivationSigmoid && activationFn instanceof ActivationTanH && BaseMKLDNNHelper.mklDnnEnabled();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
|
||||||
|
INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT,
|
||||||
|
int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey,
|
||||||
|
String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews,
|
||||||
|
INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
//Not yet implemented/supported
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input,
|
||||||
|
INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training,
|
||||||
|
INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards,
|
||||||
|
String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
|
||||||
|
/*
|
||||||
|
DL4J data format: [bS, nIn, sL] - dataFormat == 2, directionMode == 0 (forward)
|
||||||
|
Inputs:
|
||||||
|
x = [bS, nIn, sL]
|
||||||
|
Wx = [nIn, 4*nOut]
|
||||||
|
Wr = [nOut, 4*nOut]
|
||||||
|
Wp = [3*nOut] Optional peephole weights
|
||||||
|
b = [4*nOut]
|
||||||
|
seqLen = [bS]
|
||||||
|
initialOut = [bs, nOut]
|
||||||
|
initialCell = [bs, nOut]
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
out = [bS, nOut, sL]
|
||||||
|
outLast = [bs, nOut]
|
||||||
|
cellLast = [bs,nOut]
|
||||||
|
|
||||||
|
Gates order: input, forget, input modulation, output
|
||||||
|
|
||||||
|
|
||||||
|
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
|
||||||
|
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
|
||||||
|
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
|
||||||
|
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
|
||||||
|
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
|
||||||
|
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
|
||||||
|
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
||||||
|
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
|
||||||
|
*/
|
||||||
|
|
||||||
|
INDArray b1d = biases.reshape(biases.length());
|
||||||
|
INDArray seqLen = null;
|
||||||
|
if(maskArray != null){
|
||||||
|
seqLen = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), 1); //First 0 along dimension 1 (for [mb, seqLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
List<INDArray> args = new ArrayList<>();
|
||||||
|
args.add(input);
|
||||||
|
args.add(inputWeights);
|
||||||
|
args.add(recurrentWeights);
|
||||||
|
if(hasPeepholeConnections){
|
||||||
|
throw new IllegalStateException("Not yet implemented");
|
||||||
|
}
|
||||||
|
args.add(b1d);
|
||||||
|
if(seqLen != null)
|
||||||
|
args.add(seqLen);
|
||||||
|
if(prevOutputActivations != null)
|
||||||
|
args.add(prevOutputActivations);
|
||||||
|
if(prevMemCellState != null)
|
||||||
|
args.add(prevMemCellState);
|
||||||
|
|
||||||
|
IActivation a = ((LSTM)conf.getLayer()).getActivationFn();
|
||||||
|
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
|
||||||
|
.addInputs(args.toArray(new INDArray[0]))
|
||||||
|
.addBooleanArguments(
|
||||||
|
true, //hasBiases
|
||||||
|
seqLen != null, //hasSeqLen
|
||||||
|
prevOutputActivations != null, //hasInitH
|
||||||
|
prevMemCellState != null, //hasInitC
|
||||||
|
hasPeepholeConnections, //hasPh
|
||||||
|
true, //retFullSeq
|
||||||
|
true, //retLastH
|
||||||
|
true //retLastC
|
||||||
|
)
|
||||||
|
.addIntegerArguments(
|
||||||
|
2, //data format: 2 = [bS, nIn, sL]
|
||||||
|
0, //direction: 0 = forward
|
||||||
|
activationToArg(gateActivationFn), //Gate activation
|
||||||
|
activationToArg(a), //Cell state activation
|
||||||
|
activationToArg(a) //Output activation (same as cell in DL4J)
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
|
||||||
|
|
||||||
|
for(LongShapeDescriptor lsd : outShapes){
|
||||||
|
INDArray arr = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
|
||||||
|
op.addOutputArgument(arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
FwdPassReturn f = new FwdPassReturn();
|
||||||
|
f.fwdPassOutput = op.getOutputArgument(0);
|
||||||
|
f.lastAct = op.getOutputArgument(1);
|
||||||
|
f.lastMemCell = op.getOutputArgument(2);
|
||||||
|
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Long> helperMemoryUse() {
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
private int activationToArg(IActivation a){
|
||||||
|
//0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||||
|
if(a instanceof ActivationTanH)
|
||||||
|
return 0;
|
||||||
|
if(a instanceof ActivationReLU)
|
||||||
|
return 1;
|
||||||
|
if(a instanceof ActivationSigmoid)
|
||||||
|
return 2;
|
||||||
|
if(a instanceof ActivationIdentity)
|
||||||
|
return 3;
|
||||||
|
if(a instanceof ActivationLReLU)
|
||||||
|
return 4;
|
||||||
|
if(a instanceof ActivationThresholdedReLU)
|
||||||
|
return 5;
|
||||||
|
if(a instanceof ActivationHardSigmoid)
|
||||||
|
return 7;
|
||||||
|
if(a instanceof ActivationELU)
|
||||||
|
return 8;
|
||||||
|
if(a instanceof ActivationSoftSign)
|
||||||
|
return 9;
|
||||||
|
if(a instanceof ActivationSoftPlus)
|
||||||
|
return 10;
|
||||||
|
throw new IllegalStateException("Unknown or not supported activation function: " + a);
|
||||||
|
}
|
||||||
|
}
|
|
@ -118,6 +118,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config
|
INDArray globalVar = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); //One of log10std will be null depending on config
|
||||||
INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
|
INDArray globalLog10Std = params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
|
||||||
INDArray gamma = null;
|
INDArray gamma = null;
|
||||||
|
INDArray beta = null;
|
||||||
INDArray dGammaView;
|
INDArray dGammaView;
|
||||||
INDArray dBetaView;
|
INDArray dBetaView;
|
||||||
INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
|
INDArray dGlobalMeanView = gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
|
||||||
|
@ -129,6 +130,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
|
dBetaView = Nd4j.createUninitialized(dataType, tempShape, 'c');
|
||||||
} else {
|
} else {
|
||||||
gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
|
gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
|
||||||
|
beta = getParam(BatchNormalizationParamInitializer.BETA);
|
||||||
dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
|
dGammaView = gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
|
||||||
dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
|
dBetaView = gradientViews.get(BatchNormalizationParamInitializer.BETA);
|
||||||
}
|
}
|
||||||
|
@ -154,12 +156,12 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
|
|
||||||
Pair<Gradient,INDArray> ret = null;
|
Pair<Gradient,INDArray> ret = null;
|
||||||
try {
|
try {
|
||||||
ret = helper.backpropGradient(in, eps, shape, gamma, dGammaView, dBetaView,
|
ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView,
|
||||||
layerConf.getEps(), workspaceMgr);
|
layerConf.getEps(), workspaceMgr);
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
if(t.getMessage().contains("Failed to allocate")){
|
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
|
||||||
//This is a memory exception - don't fallback to built-in implementation
|
//This is a memory exception - don't fallback to built-in implementation
|
||||||
throw t;
|
throw t;
|
||||||
}
|
}
|
||||||
|
@ -451,7 +453,7 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
if(t.getMessage().contains("Failed to allocate")){
|
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
|
||||||
//This is a memory exception - don't fallback to built-in implementation
|
//This is a memory exception - don't fallback to built-in implementation
|
||||||
throw t;
|
throw t;
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,8 +31,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
public interface BatchNormalizationHelper extends LayerHelper {
|
public interface BatchNormalizationHelper extends LayerHelper {
|
||||||
boolean checkSupported(double eps, boolean fixedGammaBeta);
|
boolean checkSupported(double eps, boolean fixedGammaBeta);
|
||||||
|
|
||||||
Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma,
|
Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
|
||||||
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
|
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr);
|
||||||
|
|
||||||
INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
|
INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
|
||||||
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);
|
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr);
|
||||||
|
|
|
@ -144,7 +144,7 @@ public class LocalResponseNormalization
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
if(t.getMessage().contains("Failed to allocate")){
|
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
|
||||||
//This is a memory exception - don't fallback to built-in implementation
|
//This is a memory exception - don't fallback to built-in implementation
|
||||||
throw t;
|
throw t;
|
||||||
}
|
}
|
||||||
|
@ -211,7 +211,7 @@ public class LocalResponseNormalization
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
if(t.getMessage().contains("Failed to allocate")){
|
if(t.getMessage() != null && t.getMessage().contains("Failed to allocate")){
|
||||||
//This is a memory exception - don't fallback to built-in implementation
|
//This is a memory exception - don't fallback to built-in implementation
|
||||||
throw t;
|
throw t;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,8 @@ import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
|
||||||
|
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
|
||||||
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
import org.deeplearning4j.nn.params.LSTMParamInitializer;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -73,6 +75,16 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
//Disabled pending: https://github.com/eclipse/deeplearning4j/issues/8331
|
||||||
|
else if ("CPU".equalsIgnoreCase(backend) && BaseMKLDNNHelper.mklDnnEnabled()){
|
||||||
|
helper = new MKLDNNLSTMHelper();
|
||||||
|
log.debug("MKLDNNLSTMHelper successfully initialized");
|
||||||
|
if (!helper.checkSupported(layerConf().getGateActivationFn(), layerConf().getActivationFn(), false)) {
|
||||||
|
helper = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -215,7 +215,7 @@ public class UNet extends ZooModel {
|
||||||
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
|
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
|
||||||
.activation(Activation.RELU).build(), "conv9-2")
|
.activation(Activation.RELU).build(), "conv9-2")
|
||||||
|
|
||||||
.addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1)
|
.addLayer("conv10", new ConvolutionLayer.Builder(1,1).stride(1,1).nOut(1)
|
||||||
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
|
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
|
||||||
.activation(Activation.IDENTITY).build(), "conv9-3")
|
.activation(Activation.IDENTITY).build(), "conv9-3")
|
||||||
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
|
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT)
|
||||||
|
|
|
@ -122,7 +122,7 @@ public interface Listener {
|
||||||
/**
|
/**
|
||||||
* Called when any activation becomes available.
|
* Called when any activation becomes available.
|
||||||
* <p>
|
* <p>
|
||||||
* The activation will most likely be freed later, use detach() if you need to save it.<br>
|
* The activation will most likely be freed later, use dup() if you need to save it.<br>
|
||||||
* <br>
|
* <br>
|
||||||
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
|
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
|
||||||
* It is guaranteed to be called for variables from requiredVariables().<br>
|
* It is guaranteed to be called for variables from requiredVariables().<br>
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.nd4j.autodiff.listeners.*;
|
||||||
import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
import org.nd4j.autodiff.listeners.impl.HistoryListener;
|
||||||
import org.nd4j.autodiff.listeners.records.History;
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||||
|
import org.nd4j.autodiff.samediff.api.OutAndGrad;
|
||||||
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
|
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
|
||||||
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
|
||||||
import org.nd4j.autodiff.samediff.config.FitConfig;
|
import org.nd4j.autodiff.samediff.config.FitConfig;
|
||||||
|
@ -1642,7 +1643,13 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
Set<String> requiredVars = new HashSet<>();
|
Set<String> requiredVars = new HashSet<>();
|
||||||
for (Listener l : activeListeners) {
|
for (Listener l : activeListeners) {
|
||||||
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
|
ListenerVariables lv = l.requiredVariables(this);
|
||||||
|
if(lv != null) {
|
||||||
|
Set<String> s = lv.trainingVariables();
|
||||||
|
if(s != null) {
|
||||||
|
requiredVars.addAll(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
|
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
|
||||||
|
@ -1661,6 +1668,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
TrainingSession ts = new TrainingSession(gradInstance);
|
TrainingSession ts = new TrainingSession(gradInstance);
|
||||||
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
|
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
|
||||||
|
|
||||||
|
for(Listener l : activeListeners){
|
||||||
|
l.operationStart(gradInstance, Operation.TRAINING);
|
||||||
|
}
|
||||||
|
|
||||||
Set<String> paramsToTrain = new LinkedHashSet<>();
|
Set<String> paramsToTrain = new LinkedHashSet<>();
|
||||||
for(Variable v : variables.values()){
|
for(Variable v : variables.values()){
|
||||||
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
|
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
|
||||||
|
@ -1844,9 +1855,12 @@ public class SameDiff extends SDBaseOps {
|
||||||
*/
|
*/
|
||||||
private void validateListenerActivations(List<Listener> listeners, Operation op) {
|
private void validateListenerActivations(List<Listener> listeners, Operation op) {
|
||||||
for (Listener l : listeners) {
|
for (Listener l : listeners) {
|
||||||
for (String s : l.requiredVariables(this).requiredVariables(op)) {
|
ListenerVariables lv = l.requiredVariables(this);
|
||||||
if (!variables.containsKey(s)) {
|
if(lv != null) {
|
||||||
Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
|
for (String s : lv.requiredVariables(op)) {
|
||||||
|
if (!variables.containsKey(s)) {
|
||||||
|
Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2151,31 +2165,20 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
if (hasListeners) {
|
if (hasListeners) {
|
||||||
for (Listener l : activeListeners) {
|
for (Listener l : activeListeners) {
|
||||||
requiredVars.addAll(l.requiredVariables(this).evaluationVariables());
|
ListenerVariables v = l.requiredVariables(this);
|
||||||
|
if(v != null) {
|
||||||
|
requiredVars.addAll(v.evaluationVariables());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
String[] requiredVarsArr = requiredVars.toArray(new String[0]);
|
String[] requiredVarsArr = requiredVars.toArray(new String[0]);
|
||||||
|
|
||||||
while (iterator.hasNext()) {
|
while (iterator.hasNext()) {
|
||||||
long dataStart = hasListeners ? System.currentTimeMillis() : 0;
|
|
||||||
MultiDataSet ds = iterator.next();
|
MultiDataSet ds = iterator.next();
|
||||||
long dataEnd = hasListeners ? System.currentTimeMillis() : 0;
|
|
||||||
Map<String, INDArray> placeholderMap = toPlaceholderMap(ds);
|
Map<String, INDArray> placeholderMap = toPlaceholderMap(ds);
|
||||||
|
|
||||||
Map<String, INDArray> m;
|
Map<String, INDArray> m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
||||||
Map<String, INDArray> outs = null;
|
|
||||||
if (hasListeners) {
|
|
||||||
|
|
||||||
for (Listener l : activeListeners) {
|
|
||||||
l.iterationStart(this, at, ds, (dataEnd - dataStart));
|
|
||||||
}
|
|
||||||
|
|
||||||
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
|
||||||
} else {
|
|
||||||
m = directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
|
for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
|
||||||
INDArray prediction = m.get(e.getKey());
|
INDArray prediction = m.get(e.getKey());
|
||||||
|
@ -2188,15 +2191,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasListeners) {
|
|
||||||
for (Listener l : activeListeners) {
|
|
||||||
Map<String, INDArray> outVars = Maps.newHashMap(
|
|
||||||
Maps.filterKeys(outs,
|
|
||||||
Predicates.in(l.requiredVariables(this).evaluationVariables())));
|
|
||||||
l.iterationDone(this, at, ds, null);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
at.setIteration(at.iteration() + 1);
|
at.setIteration(at.iteration() + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2518,7 +2512,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
* Special case of {@link #batchOutput()}.
|
* Special case of {@link #batchOutput()}.
|
||||||
*/
|
*/
|
||||||
public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
|
public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
|
||||||
return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec();
|
return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2529,7 +2523,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
* Special case of {@link #batchOutput()}.
|
* Special case of {@link #batchOutput()}.
|
||||||
*/
|
*/
|
||||||
public Map<String, INDArray> output(Map<String, INDArray> placeholders, String... outputs) {
|
public Map<String, INDArray> output(Map<String, INDArray> placeholders, String... outputs) {
|
||||||
return batchOutput().output(outputs).inputs(placeholders).exec();
|
return batchOutput().output(outputs).inputs(placeholders).output();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2542,31 +2536,36 @@ public class SameDiff extends SDBaseOps {
|
||||||
* @param listeners Additional listeners to use during this operation.
|
* @param listeners Additional listeners to use during this operation.
|
||||||
* @param outputs The variables to output and return.
|
* @param outputs The variables to output and return.
|
||||||
*/
|
*/
|
||||||
public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
|
public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<Listener> listeners, String... outputs) {
|
||||||
return batchOutputHelper(placeholders, listeners, outputs);
|
return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs) {
|
protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs) {
|
||||||
List<Listener> activeListeners = new ArrayList<>();
|
List<Listener> activeListeners = new ArrayList<>();
|
||||||
|
|
||||||
|
if(operation == null)
|
||||||
|
operation = Operation.INFERENCE;
|
||||||
|
|
||||||
for (Listener l : this.listeners)
|
for (Listener l : this.listeners)
|
||||||
if (l.isActive(Operation.INFERENCE))
|
if (l.isActive(operation))
|
||||||
activeListeners.add(l);
|
activeListeners.add(l);
|
||||||
|
|
||||||
for (Listener l : listeners)
|
if(listeners != null) {
|
||||||
if (l.isActive(Operation.INFERENCE))
|
for (Listener l : listeners)
|
||||||
activeListeners.add(l);
|
if (l.isActive(operation))
|
||||||
|
activeListeners.add(l);
|
||||||
for (Listener l : activeListeners) {
|
|
||||||
l.operationStart(this, Operation.INFERENCE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validateListenerActivations(activeListeners, Operation.INFERENCE);
|
for (Listener l : activeListeners) {
|
||||||
|
l.operationStart(this, operation);
|
||||||
|
}
|
||||||
|
|
||||||
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.<String>emptyList(), activeListeners, outputs);
|
validateListenerActivations(activeListeners, operation);
|
||||||
|
|
||||||
|
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);
|
||||||
|
|
||||||
for (Listener l : activeListeners) {
|
for (Listener l : activeListeners) {
|
||||||
l.operationEnd(this, Operation.INFERENCE);
|
l.operationEnd(this, operation);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -3992,7 +3991,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
sameDiffFunctionInstances.put(function, sub);
|
sameDiffFunctionInstances.put(function, sub);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4012,32 +4010,64 @@ public class SameDiff extends SDBaseOps {
|
||||||
*/
|
*/
|
||||||
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
|
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
|
||||||
Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
|
Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
|
||||||
|
OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables);
|
||||||
|
return oag.getGradients();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the activations and the gradients for the specified variables, in one execution call.
|
||||||
|
* This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but
|
||||||
|
* is more efficient than calling both separately.
|
||||||
|
*
|
||||||
|
* @param placeholderVals Placeholders. May be null
|
||||||
|
* @param outputVars Names of the variables that you want the activations/outputs for. May be null
|
||||||
|
* @param gradientVars Names of the variables that you want the gradient arrays for. May be null
|
||||||
|
* @return Activations and gradients, keyed by variable name
|
||||||
|
*/
|
||||||
|
public OutAndGrad calculateGradientsAndOutputs(Map<String,INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars){
|
||||||
|
Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()),
|
||||||
|
"No variables were specified for either output or gradients");
|
||||||
if (getFunction(GRAD_FN_KEY) == null) {
|
if (getFunction(GRAD_FN_KEY) == null) {
|
||||||
createGradFunction();
|
createGradFunction();
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> gradVarNames = new ArrayList<>(variables.size());
|
List<String> varNames = new ArrayList<>();
|
||||||
for (String s : variables) {
|
if(outputVars != null){
|
||||||
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
|
varNames.addAll(outputVars);
|
||||||
SDVariable v = getVariable(s).getGradient();
|
}
|
||||||
if (v != null) {
|
if(gradientVars != null) {
|
||||||
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
|
for (String s : gradientVars) {
|
||||||
gradVarNames.add(v.name());
|
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
|
||||||
|
SDVariable v = getVariable(s).getGradient();
|
||||||
|
if (v != null) {
|
||||||
|
//In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable
|
||||||
|
varNames.add(v.name());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Key is gradient variable name
|
//Key is gradient variable name
|
||||||
Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);
|
SameDiff gradFn = getFunction(GRAD_FN_KEY);
|
||||||
|
gradFn.setListeners(listeners);
|
||||||
|
Map<String, INDArray> grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0]));
|
||||||
|
|
||||||
Map<String, INDArray> out = new HashMap<>();
|
Map<String, INDArray> outOutputs = outputVars == null ? null : new HashMap<String,INDArray>();
|
||||||
for (String s : variables) {
|
Map<String, INDArray> outGrads = gradientVars == null ? null : new HashMap<String,INDArray>();
|
||||||
if (getVariable(s).getGradient() != null) {
|
if(outputVars != null){
|
||||||
String gradVar = getVariable(s).getGradient().name();
|
for(String s : outputVars){
|
||||||
out.put(s, grads.get(gradVar));
|
outOutputs.put(s, grads.get(s));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(gradientVars != null) {
|
||||||
|
for (String s : gradientVars) {
|
||||||
|
if (getVariable(s).getGradient() != null) {
|
||||||
|
String gradVar = getVariable(s).getGradient().name();
|
||||||
|
outGrads.put(s, grads.get(gradVar));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out;
|
return new OutAndGrad(outOutputs, outGrads);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
package org.nd4j.autodiff.samediff.api;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple object holding two maps - one of output arrays, another of gradient arrays
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Data
|
||||||
|
public class OutAndGrad {
|
||||||
|
|
||||||
|
private final Map<String, INDArray> outputs;
|
||||||
|
private final Map<String, INDArray> gradients;
|
||||||
|
|
||||||
|
}
|
|
@ -797,7 +797,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
|
||||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||||
args[i] = v.getArr();
|
args[i] = v.getArr();
|
||||||
} else if (v.isPlaceHolder()) {
|
} else if (v.isPlaceHolder()) {
|
||||||
Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array provided for placeholder %s", s);
|
Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", s);
|
||||||
args[i] = placeholderValues.get(s);
|
args[i] = placeholderValues.get(s);
|
||||||
} else {
|
} else {
|
||||||
VarId vid = lookup(s, opInputs, allIterInputs, true);
|
VarId vid = lookup(s, opInputs, allIterInputs, true);
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
@ -125,7 +126,9 @@ public class ImagePreProcessingScaler implements DataNormalization {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void transformLabel(INDArray label) {
|
public void transformLabel(INDArray label) {
|
||||||
//No op
|
Preconditions.checkState(label != null && label.rank() == 4, "Labels can only be transformed for segmentation use" +
|
||||||
|
" cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", label);
|
||||||
|
transform(label);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -161,7 +164,9 @@ public class ImagePreProcessingScaler implements DataNormalization {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void revertLabels(INDArray labels) {
|
public void revertLabels(INDArray labels) {
|
||||||
//No op
|
Preconditions.checkState(labels != null && labels.rank() == 4, "Labels can only be transformed for segmentation use" +
|
||||||
|
" cases using this preprocesser - i.e., labels must be rank 4. Got: %ndShape", labels);
|
||||||
|
revertFeatures(labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -171,9 +176,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fitLabel(boolean fitLabels) {
|
public void fitLabel(boolean fitLabels) {
|
||||||
if (fitLabels) {
|
//No-op
|
||||||
log.warn("Labels fitting not currently supported for ImagePreProcessingScaler. Labels will not be modified");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -5831,12 +5831,6 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
UInt8Buffer b = new UInt8Buffer(ArrayUtil.prod(shapeOf));
|
|
||||||
val sb = bb.order(_order).asReadOnlyBuffer();
|
|
||||||
for (int e = 0; e < prod; e++)
|
|
||||||
b.put(e, sb.get(e));
|
|
||||||
|
|
||||||
return Nd4j.create(b, shapeOf);
|
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
case UINT16:
|
case UINT16:
|
||||||
INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf);
|
INDArray arr = Nd4j.createUninitialized(_dtype, shapeOf);
|
||||||
|
|
|
@ -1000,8 +1000,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val dataType = op.resultType();
|
val dataType = op.resultType();
|
||||||
|
|
||||||
val ret = Nd4j.createUninitialized(dataType, retShape);
|
if( op.z() == null ){
|
||||||
op.setZ(ret);
|
val ret = Nd4j.createUninitialized(dataType, retShape);
|
||||||
|
op.setZ(ret);
|
||||||
|
} else if(op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())){
|
||||||
|
throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape)
|
||||||
|
+ " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape()));
|
||||||
|
}
|
||||||
|
|
||||||
val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
|
val eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
|
||||||
Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;
|
Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;
|
||||||
|
|
|
@ -39,6 +39,7 @@ import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.OpValidationSuite;
|
import org.nd4j.OpValidationSuite;
|
||||||
|
import org.nd4j.autodiff.samediff.api.OutAndGrad;
|
||||||
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
|
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
|
||||||
import org.nd4j.autodiff.validation.OpValidation;
|
import org.nd4j.autodiff.validation.OpValidation;
|
||||||
import org.nd4j.autodiff.validation.TestCase;
|
import org.nd4j.autodiff.validation.TestCase;
|
||||||
|
@ -3426,4 +3427,30 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
INDArray a1 = rand1.eval();
|
INDArray a1 = rand1.eval();
|
||||||
assertEquals(a0, a1);
|
assertEquals(a0, a1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCalculateGradientsAndOutputs(){
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||||
|
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
|
||||||
|
SDVariable z = in.mmul(w).add("z", b);
|
||||||
|
SDVariable softmax = sd.nn.softmax("softmax", z);
|
||||||
|
|
||||||
|
Map<String,INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
|
||||||
|
List<String> outputs = Arrays.asList("in", "z", "softmax");
|
||||||
|
List<String> grads = Arrays.asList("in", "w", "z");
|
||||||
|
|
||||||
|
OutAndGrad oag = sd.calculateGradientsAndOutputs(ph, outputs, grads);
|
||||||
|
Map<String,INDArray> outs = oag.getOutputs();
|
||||||
|
Map<String,INDArray> g = oag.getGradients();
|
||||||
|
|
||||||
|
|
||||||
|
Map<String,INDArray> outExp = sd.output(ph, outputs);
|
||||||
|
Map<String,INDArray> gExp = sd.calculateGradients(ph, grads);
|
||||||
|
|
||||||
|
assertEquals(outExp, outs);
|
||||||
|
assertEquals(gExp, g);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,22 +19,28 @@ package org.nd4j.autodiff.samediff.listeners;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.autodiff.listeners.Operation;
|
import org.nd4j.autodiff.listeners.*;
|
||||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||||
import org.nd4j.autodiff.listeners.records.History;
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
|
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -49,6 +55,11 @@ public class ListenerTest extends BaseNd4jTest {
|
||||||
super(backend);
|
super(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void irisHistoryTest() {
|
public void irisHistoryTest() {
|
||||||
|
|
||||||
|
@ -112,8 +123,237 @@ public class ListenerTest extends BaseNd4jTest {
|
||||||
assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
|
assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Test
|
||||||
public char ordering() {
|
public void testListenerCalls(){
|
||||||
return 'c';
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||||
|
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||||
|
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
|
||||||
|
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3));
|
||||||
|
SDVariable z = in.mmul(w).add(b);
|
||||||
|
SDVariable softmax = sd.nn.softmax("softmax", z);
|
||||||
|
SDVariable loss = sd.loss.logLoss("loss" ,label, softmax);
|
||||||
|
|
||||||
|
TestListener tl = new TestListener(Operation.INFERENCE);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
|
||||||
|
//Check listener called during inference
|
||||||
|
Map<String,INDArray> phMap = Collections.singletonMap("in", Nd4j.rand(1, 4));
|
||||||
|
|
||||||
|
for( int i=1; i<=5; i++ ) {
|
||||||
|
INDArray out = sd.outputSingle(phMap, "softmax");
|
||||||
|
|
||||||
|
assertEquals(0, tl.epochStartCount);
|
||||||
|
assertEquals(0, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(0, tl.iterationStartCount);
|
||||||
|
assertEquals(0, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.INFERENCE, i), tl.operationEndCount);
|
||||||
|
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
|
||||||
|
assertEquals(3*i, tl.opExecutionCount);
|
||||||
|
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax outputs
|
||||||
|
assertEquals(0, tl.preUpdateCount); //Inference -> no updating
|
||||||
|
}
|
||||||
|
|
||||||
|
//Check listener NOT called during inference when set to Operation.TRAINING
|
||||||
|
tl = new TestListener(Operation.TRAINING);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
sd.outputSingle(phMap, "softmax");
|
||||||
|
|
||||||
|
assertEquals(0, tl.epochStartCount);
|
||||||
|
assertEquals(0, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(0, tl.iterationStartCount);
|
||||||
|
assertEquals(0, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.emptyMap(), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.emptyMap(), tl.operationEndCount);
|
||||||
|
assertEquals(0, tl.preOpExecutionCount);
|
||||||
|
assertEquals(0, tl.opExecutionCount);
|
||||||
|
assertEquals(0, tl.activationAvailableCount);
|
||||||
|
assertEquals(0, tl.preUpdateCount);
|
||||||
|
|
||||||
|
//Check listener called during gradient calculation
|
||||||
|
tl = new TestListener(Operation.TRAINING);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
phMap = new HashMap<>();
|
||||||
|
phMap.put("in", Nd4j.rand( DataType.FLOAT, 1, 4));
|
||||||
|
phMap.put("label", Nd4j.createFromArray(0f, 1f, 0f).reshape(1, 3));
|
||||||
|
|
||||||
|
for( int i=1; i<=3; i++ ) {
|
||||||
|
sd.calculateGradients(phMap, "in", "w", "b");
|
||||||
|
assertEquals(0, tl.epochStartCount);
|
||||||
|
assertEquals(0, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(0, tl.iterationStartCount);
|
||||||
|
assertEquals(0, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
|
||||||
|
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
|
||||||
|
assertEquals(7*i, tl.opExecutionCount);
|
||||||
|
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
|
||||||
|
assertEquals(0, tl.preUpdateCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Check listener NOT called during gradient calculation - when listener is still set to INFERENCE mode
|
||||||
|
tl = new TestListener(Operation.INFERENCE);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
for( int i=1; i<=3; i++ ) {
|
||||||
|
sd.calculateGradients(phMap, "in", "w", "b");
|
||||||
|
assertEquals(0, tl.epochStartCount);
|
||||||
|
assertEquals(0, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(0, tl.iterationStartCount);
|
||||||
|
assertEquals(0, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.emptyMap(), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.emptyMap(), tl.operationEndCount);
|
||||||
|
assertEquals(0, tl.preOpExecutionCount);
|
||||||
|
assertEquals(0, tl.opExecutionCount);
|
||||||
|
assertEquals(0, tl.activationAvailableCount);
|
||||||
|
assertEquals(0, tl.preUpdateCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
//Check fit:
|
||||||
|
tl = new TestListener(Operation.TRAINING);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
sd.setTrainingConfig(TrainingConfig.builder()
|
||||||
|
.dataSetFeatureMapping("in")
|
||||||
|
.dataSetLabelMapping("label")
|
||||||
|
.updater(new Adam(1e-3))
|
||||||
|
.build());
|
||||||
|
|
||||||
|
SingletonDataSetIterator dsi = new SingletonDataSetIterator(new DataSet(phMap.get("in"), phMap.get("label")));
|
||||||
|
for( int i=1; i<=3; i++ ) {
|
||||||
|
sd.fit(dsi, 1);
|
||||||
|
assertEquals(i, tl.epochStartCount);
|
||||||
|
assertEquals(i, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(i, tl.iterationStartCount);
|
||||||
|
assertEquals(i, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.TRAINING, i), tl.operationEndCount);
|
||||||
|
assertEquals(7*i, tl.preOpExecutionCount); //mmul, add, softmax, loss grad, softmax backward, add backward, mmul backward
|
||||||
|
assertEquals(7*i, tl.opExecutionCount);
|
||||||
|
assertEquals(11*i, tl.activationAvailableCount); //mmul, add, softmax, loss grad (weight, in, label), softmax bp, add backward (z, b), mmul (in, w)
|
||||||
|
assertEquals(2*i, tl.preUpdateCount); //w, b
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Check evaluation:
|
||||||
|
tl = new TestListener(Operation.EVALUATION);
|
||||||
|
sd.setListeners(tl);
|
||||||
|
|
||||||
|
for( int i=1; i<=3; i++ ) {
|
||||||
|
sd.evaluate(dsi, "softmax", new Evaluation());
|
||||||
|
assertEquals(0, tl.epochStartCount);
|
||||||
|
assertEquals(0, tl.epochEndCount);
|
||||||
|
assertEquals(0, tl.validationDoneCount);
|
||||||
|
assertEquals(0, tl.iterationStartCount);
|
||||||
|
assertEquals(0, tl.iterationDoneCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationStartCount);
|
||||||
|
assertEquals(Collections.singletonMap(Operation.EVALUATION, i), tl.operationEndCount);
|
||||||
|
assertEquals(3*i, tl.preOpExecutionCount); //mmul, add, softmax
|
||||||
|
assertEquals(3*i, tl.opExecutionCount);
|
||||||
|
assertEquals(3*i, tl.activationAvailableCount); //mmul, add, softmax
|
||||||
|
assertEquals(0, tl.preUpdateCount); //w, b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestListener implements Listener {
|
||||||
|
|
||||||
|
public TestListener(Operation operation){
|
||||||
|
this.operation = operation;
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Operation operation;
|
||||||
|
|
||||||
|
private int epochStartCount = 0;
|
||||||
|
private int epochEndCount = 0;
|
||||||
|
private int validationDoneCount = 0;
|
||||||
|
private int iterationStartCount = 0;
|
||||||
|
private int iterationDoneCount = 0;
|
||||||
|
private Map<Operation,Integer> operationStartCount = new HashMap<>();
|
||||||
|
private Map<Operation,Integer> operationEndCount = new HashMap<>();
|
||||||
|
private int preOpExecutionCount = 0;
|
||||||
|
private int opExecutionCount = 0;
|
||||||
|
private int activationAvailableCount = 0;
|
||||||
|
private int preUpdateCount = 0;
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerVariables requiredVariables(SameDiff sd) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isActive(Operation operation) {
|
||||||
|
return this.operation == null || this.operation == operation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void epochStart(SameDiff sd, At at) {
|
||||||
|
epochStartCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||||
|
epochEndCount++;
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
|
||||||
|
validationDoneCount++;
|
||||||
|
return ListenerResponse.CONTINUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlTimeMs) {
|
||||||
|
iterationStartCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
|
||||||
|
iterationDoneCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationStart(SameDiff sd, Operation op) {
|
||||||
|
if(!operationStartCount.containsKey(op)) {
|
||||||
|
operationStartCount.put(op, 1);
|
||||||
|
} else {
|
||||||
|
operationStartCount.put(op, operationStartCount.get(op) + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void operationEnd(SameDiff sd, Operation op) {
|
||||||
|
if(!operationEndCount.containsKey(op)) {
|
||||||
|
operationEndCount.put(op, 1);
|
||||||
|
} else {
|
||||||
|
operationEndCount.put(op, operationEndCount.get(op) + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||||
|
preOpExecutionCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||||
|
opExecutionCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) {
|
||||||
|
activationAvailableCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preUpdate(SameDiff sd, At at, Variable v, INDArray update) {
|
||||||
|
preUpdateCount++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,11 +20,13 @@ import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
|
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
@ -42,32 +44,32 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void simpleImageTest() {
|
public void simpleImageTest() {
|
||||||
INDArray rChannels = Nd4j.zeros(10, 10).addi(128);
|
INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128);
|
||||||
INDArray gChannels = Nd4j.zeros(10, 10).addi(64);
|
INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64);
|
||||||
INDArray bChannels = Nd4j.zeros(10, 10).addi(255);
|
INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255);
|
||||||
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(3, 10, 10);
|
INDArray image = Nd4j.vstack(rChannels, gChannels, bChannels).reshape(1, 3, 10, 10);
|
||||||
INDArray orig = image.dup();
|
INDArray orig = image.dup();
|
||||||
|
|
||||||
//System.out.println(Arrays.toString(image.shape()));
|
//System.out.println(Arrays.toString(image.shape()));
|
||||||
DataSet ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
|
DataSet ds = new DataSet(image, Nd4j.ones(1, 1));
|
||||||
ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler();
|
ImagePreProcessingScaler myScaler = new ImagePreProcessingScaler();
|
||||||
//So this should scale to 0.5,0.25 and 1;
|
//So this should scale to 0.5,0.25 and 1;
|
||||||
INDArray expected = image.mul(0);
|
INDArray expected = image.mul(0);
|
||||||
expected.slice(0, 0).addi(0.5);
|
expected.slice(0, 1).addi(0.5);
|
||||||
expected.slice(1, 0).addi(0.25);
|
expected.slice(1, 1).addi(0.25);
|
||||||
expected.slice(2, 0).addi(1.0);
|
expected.slice(2, 1).addi(1.0);
|
||||||
myScaler.transform(ds);
|
myScaler.transform(ds);
|
||||||
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
||||||
|
|
||||||
//Now giving it 16 bits instead of the default
|
//Now giving it 16 bits instead of the default
|
||||||
//System.out.println(Arrays.toString(image.shape()));
|
//System.out.println(Arrays.toString(image.shape()));
|
||||||
ds = new DataSet(image.reshape(1, 3, 10, 10), Nd4j.ones(1, 1));
|
ds = new DataSet(image, Nd4j.ones(1, 1));
|
||||||
myScaler = new ImagePreProcessingScaler(0, 1, 16);
|
myScaler = new ImagePreProcessingScaler(0, 1, 16);
|
||||||
//So this should scale to 0.5,0.25 and 1;
|
//So this should scale to 0.5,0.25 and 1;
|
||||||
expected = image.mul(0);
|
expected = image.mul(0);
|
||||||
expected.slice(0, 0).addi(0.5 / 256);
|
expected.slice(0, 1).addi(0.5 / 256);
|
||||||
expected.slice(1, 0).addi(0.25 / 256);
|
expected.slice(1, 1).addi(0.25 / 256);
|
||||||
expected.slice(2, 0).addi(1.0 / 256);
|
expected.slice(2, 1).addi(1.0 / 256);
|
||||||
myScaler.transform(ds);
|
myScaler.transform(ds);
|
||||||
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
assertTrue(Transforms.abs(ds.getFeatures().sub(expected)).maxNumber().doubleValue() <= 0.01);
|
||||||
|
|
||||||
|
@ -88,6 +90,16 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
|
||||||
myScaler.transform(before);
|
myScaler.transform(before);
|
||||||
myScaler.revertFeatures(before);
|
myScaler.revertFeatures(before);
|
||||||
assertEquals(orig, before);
|
assertEquals(orig, before);
|
||||||
|
|
||||||
|
|
||||||
|
//Test labels (segmentation case)
|
||||||
|
before = orig.dup();
|
||||||
|
myScaler = new ImagePreProcessingScaler(0, 1);
|
||||||
|
myScaler.transformLabel(before);
|
||||||
|
expected = orig.div(255);
|
||||||
|
assertEquals(expected, before);
|
||||||
|
myScaler.revertLabels(before);
|
||||||
|
assertEquals(orig, before);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue