parent
0bed17c97f
commit
997143b9dd
|
@ -88,7 +88,7 @@ public class CNN1DTestCases {
|
||||||
.convolutionMode(ConvolutionMode.Same))
|
.convolutionMode(ConvolutionMode.Same))
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", Convolution1DLayer.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
.layer("0", Convolution1D.builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in")
|
||||||
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
.layer("1", Subsampling1DLayer.builder().kernelSize(2).stride(1).poolingType(SubsamplingLayer.PoolingType.MAX.toPoolingType()).build(), "0")
|
||||||
.layer("2", Cropping1D.builder(1).build(), "1")
|
.layer("2", Cropping1D.builder(1).build(), "1")
|
||||||
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
.layer("3", ZeroPadding1DLayer.builder(1).build(), "2")
|
||||||
|
|
|
@ -77,7 +77,7 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder().updater(new NoOp())
|
NeuralNetConfiguration.builder().updater(new NoOp())
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.seed(12345L)
|
.seed(12345L)
|
||||||
.dist(new NormalDistribution(0, 1)).list()
|
.weightInit(new NormalDistribution(0, 1))
|
||||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3)
|
||||||
.activation(Activation.IDENTITY).build())
|
.activation(Activation.IDENTITY).build())
|
||||||
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
.layer(1,BatchNormalization.builder().useLogStd(useLogStd).nOut(3).build())
|
||||||
|
|
|
@ -93,7 +93,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -202,7 +202,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -211,7 +211,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(Cropping1D.builder(cropping).build())
|
.layer(Cropping1D.builder(cropping).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -317,7 +317,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -326,7 +326,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
.layer(ZeroPadding1DLayer.builder(zeroPadding).build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -438,7 +438,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
0,
|
0,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -447,7 +447,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
1,
|
1,
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.activation(afn)
|
.activation(afn)
|
||||||
.kernelSize(kernel)
|
.kernelSize(kernel)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -548,7 +548,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -562,7 +562,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.pnorm(pnorm)
|
.pnorm(pnorm)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.rnnDataFormat(RNNFormat.NCW)
|
.rnnDataFormat(RNNFormat.NCW)
|
||||||
.stride(stride)
|
.stride(stride)
|
||||||
|
@ -655,7 +655,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
|
@ -664,7 +664,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
.nOut(convNOut1)
|
.nOut(convNOut1)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.kernelSize(k)
|
.kernelSize(k)
|
||||||
.dilation(d)
|
.dilation(d)
|
||||||
.convolutionMode(ConvolutionMode.Causal)
|
.convolutionMode(ConvolutionMode.Causal)
|
||||||
|
|
|
@ -180,7 +180,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
Pooling2D.class, //Alias for SubsamplingLayer
|
Pooling2D.class, //Alias for SubsamplingLayer
|
||||||
Convolution2D.class, //Alias for ConvolutionLayer
|
Convolution2D.class, //Alias for ConvolutionLayer
|
||||||
Pooling1D.class, //Alias for Subsampling1D
|
Pooling1D.class, //Alias for Subsampling1D
|
||||||
Convolution1D.class, //Alias for Convolution1DLayer
|
Convolution1D.class, //Alias for Convolution1D
|
||||||
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.Timeout;
|
import org.junit.jupiter.api.Timeout;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -1026,7 +1026,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
} catch (DL4JInvalidInputException e) {
|
} catch (DL4JInvalidInputException e) {
|
||||||
// e.printStackTrace();
|
// e.printStackTrace();
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
assertTrue(msg.contains(Convolution2DUtils.NCHW_NHWC_ERROR_MSG) || msg.contains("input array channels does not match CNN layer configuration"), msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -921,7 +921,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.builder()
|
NeuralNetConfiguration.builder()
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(
|
.layer(
|
||||||
Convolution1DLayer.builder()
|
Convolution1D.builder()
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.kernelSize(2)
|
.kernelSize(2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
|
@ -975,7 +975,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1dCausalAllowed() {
|
public void testConv1dCausalAllowed() {
|
||||||
Convolution1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Convolution1D.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
Subsampling1DLayer.builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -346,7 +346,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
int[] outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
int[] outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
|
|
||||||
|
@ -367,7 +367,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(3, it.getHeight());
|
assertEquals(3, it.getHeight());
|
||||||
assertEquals(3, it.getWidth());
|
assertEquals(3, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(3, outSize[0]);
|
assertEquals(3, outSize[0]);
|
||||||
assertEquals(3, outSize[1]);
|
assertEquals(3, outSize[1]);
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Strict);
|
||||||
fail("Exception expected");
|
fail("Exception expected");
|
||||||
} catch (DL4JException e) {
|
} catch (DL4JException e) {
|
||||||
System.out.println(e.getMessage());
|
System.out.println(e.getMessage());
|
||||||
|
@ -409,7 +409,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(1, it.getHeight());
|
assertEquals(1, it.getHeight());
|
||||||
assertEquals(1, it.getWidth());
|
assertEquals(1, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, padding, ConvolutionMode.Truncate);
|
||||||
assertEquals(1, outSize[0]);
|
assertEquals(1, outSize[0]);
|
||||||
assertEquals(1, outSize[1]);
|
assertEquals(1, outSize[1]);
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ public class TestConvolutionModes extends BaseDL4JTest {
|
||||||
assertEquals(2, it.getHeight());
|
assertEquals(2, it.getHeight());
|
||||||
assertEquals(2, it.getWidth());
|
assertEquals(2, it.getWidth());
|
||||||
assertEquals(dOut, it.getChannels());
|
assertEquals(dOut, it.getChannels());
|
||||||
outSize = ConvolutionUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
outSize = Convolution2DUtils.getOutputSize(inData, kernel, stride, null, ConvolutionMode.Same);
|
||||||
assertEquals(2, outSize[0]);
|
assertEquals(2, outSize[0]);
|
||||||
assertEquals(2, outSize[1]);
|
assertEquals(2, outSize[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -732,7 +732,7 @@ public class BatchNormalizationTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
.layer(rnn ? LSTM.builder().nOut(3).build() :
|
||||||
Convolution1DLayer.builder().kernelSize(3).stride(1).nOut(3).build())
|
Convolution1D.builder().kernelSize(3).stride(1).nOut(3).build())
|
||||||
.layer(BatchNormalization.builder().build())
|
.layer(BatchNormalization.builder().build())
|
||||||
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
.layer(RnnOutputLayer.builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
|
||||||
.inputType(InputType.recurrent(3))
|
.inputType(InputType.recurrent(3))
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(output)
|
.setOutputs(output)
|
||||||
.layer(conv, Convolution1DLayer.builder(7)
|
.layer(conv, Convolution1D.builder(7)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.nOut(input.size(1))
|
.nOut(input.size(1))
|
||||||
.weightInit(new WeightInitIdentity())
|
.weightInit(new WeightInitIdentity())
|
||||||
|
|
|
@ -38,7 +38,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
import org.deeplearning4j.cuda.BaseCudnnHelper;
|
||||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
|
||||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.jita.allocator.Allocator;
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.conf.CudaEnvironment;
|
import org.nd4j.jita.conf.CudaEnvironment;
|
||||||
|
@ -681,9 +681,9 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
|
|
||||||
int[] outSize;
|
int[] outSize;
|
||||||
if (convolutionMode == ConvolutionMode.Same) {
|
if (convolutionMode == ConvolutionMode.Same) {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation
|
||||||
padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
padding = Convolution2DUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
int[] padBottomRight = Convolution2DUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation);
|
||||||
if(!Arrays.equals(padding, padBottomRight)){
|
if(!Arrays.equals(padding, padBottomRight)){
|
||||||
/*
|
/*
|
||||||
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric
|
||||||
|
@ -731,7 +731,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
// CuDNN handle
|
// CuDNN handle
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
outSize = Convolution2DUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation
|
||||||
}
|
}
|
||||||
|
|
||||||
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize);
|
||||||
|
|
|
@ -42,7 +42,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils;
|
||||||
import org.deeplearning4j.util.ConvolutionUtils;
|
import org.deeplearning4j.util.Convolution2DUtils;
|
||||||
import org.nd4j.common.primitives.Counter;
|
import org.nd4j.common.primitives.Counter;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -442,8 +442,8 @@ public class KerasModel {
|
||||||
KerasInput kerasInput = (KerasInput) layer;
|
KerasInput kerasInput = (KerasInput) layer;
|
||||||
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer;
|
||||||
//no dim order, try to pull it from the next layer if there is one
|
//no dim order, try to pull it from the next layer if there is one
|
||||||
if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) {
|
if(Convolution2DUtils.layerHasConvolutionLayout(layer1)) {
|
||||||
CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1);
|
CNN2DFormat formatForLayer = Convolution2DUtils.getFormatForLayer(layer1);
|
||||||
if(formatForLayer == CNN2DFormat.NCHW) {
|
if(formatForLayer == CNN2DFormat.NCHW) {
|
||||||
dimOrder = KerasLayer.DimOrder.THEANO;
|
dimOrder = KerasLayer.DimOrder.THEANO;
|
||||||
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
} else if(formatForLayer == CNN2DFormat.NHWC) {
|
||||||
|
|
|
@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -84,29 +84,29 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
|
|
||||||
ConvolutionLayer.ConvolutionLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
.dilation(getDilationRate(layerConfig, 1, conf, true)[0])
|
||||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion))
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW)
|
||||||
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
|
.stride(getStrideFromConfig(layerConfig, 1, conf));
|
||||||
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
builder.biasInit(0.0);
|
builder.biasInit(0.0);
|
||||||
if (padding != null)
|
if (padding != null)
|
||||||
builder.padding(padding[0]);
|
builder.padding(padding);
|
||||||
if (biasConstraint != null)
|
if (biasConstraint != null)
|
||||||
builder.constrainBias(biasConstraint);
|
builder.constrainBias(biasConstraint);
|
||||||
if (weightConstraint != null)
|
if (weightConstraint != null)
|
||||||
builder.constrainWeights(weightConstraint);
|
builder.constrainWeights(weightConstraint);
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer;
|
Convolution1D convolution1D = (Convolution1D) layer;
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -114,8 +114,8 @@ public class KerasAtrousConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getAtrousConvolution1D() {
|
public Convolution1D getAtrousConvolution1D() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
@ -93,7 +93,7 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
|
||||||
enforceTrainingConfig, conf, kerasMajorVersion);
|
enforceTrainingConfig, conf, kerasMajorVersion);
|
||||||
Convolution1DLayer.Convolution1DLayerBuilder builder = Convolution1DLayer.builder().name(this.name)
|
var builder = Convolution1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getIActivationFromConfig(layerConfig, conf))
|
.activation(getIActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(init)
|
.weightInit(init)
|
||||||
|
@ -125,9 +125,9 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
|
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
//set this in order to infer the dimensional format
|
//set this in order to infer the dimensional format
|
||||||
Convolution1DLayer convolution1DLayer = (Convolution1DLayer) this.layer;
|
Convolution1D convolution1D = (Convolution1D) this.layer;
|
||||||
convolution1DLayer.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
convolution1D.setDataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
|
||||||
convolution1DLayer.setDefaultValueOverriden(true);
|
convolution1D.setDefaultValueOverriden(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,8 +135,8 @@ public class KerasConvolution1D extends KerasConvolution {
|
||||||
*
|
*
|
||||||
* @return ConvolutionLayer
|
* @return ConvolutionLayer
|
||||||
*/
|
*/
|
||||||
public Convolution1DLayer getConvolution1DLayer() {
|
public Convolution1D getConvolution1DLayer() {
|
||||||
return (Convolution1DLayer) this.layer;
|
return (Convolution1D) this.layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,8 @@ import org.deeplearning4j.gradientcheck.GradientCheckUtil;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
import org.deeplearning4j.nn.api.layers.IOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
|
||||||
|
@ -656,7 +653,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
|
||||||
true, true, false, null, null);
|
true, true, false, null, null);
|
||||||
Layer l = net.getLayer(0);
|
Layer l = net.getLayer(0);
|
||||||
Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig();
|
Convolution1D c1d = (Convolution1D) l.getTrainingConfig();
|
||||||
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -97,7 +97,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
Convolution1D layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.convolution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
|
import org.deeplearning4j.nn.conf.layers.Convolution1D;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
import org.deeplearning4j.nn.modelimport.keras.KerasTestUtils;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
import org.deeplearning4j.nn.modelimport.keras.config.Keras1LayerConfiguration;
|
||||||
|
@ -119,7 +119,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest {
|
||||||
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
config.put(conf.getLAYER_FIELD_BORDER_MODE(), BORDER_MODE_VALID);
|
||||||
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config);
|
||||||
|
|
||||||
Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
Convolution1D layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer();
|
||||||
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString());
|
||||||
assertEquals(LAYER_NAME, layer.getName());
|
assertEquals(LAYER_NAME, layer.getName());
|
||||||
assertEquals(INIT_DL4J, layer.getWeightInit());
|
assertEquals(INIT_DL4J, layer.getWeightInit());
|
||||||
|
|
|
@ -22,8 +22,6 @@
|
||||||
package net.brutex.ai.dnn.api;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
|
|
||||||
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
public interface INeuralNetworkConfiguration extends Serializable, Cloneable {
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A fluent API to configure and create artificial neural networks
|
* A fluent API to configure and create artificial neural networks
|
||||||
|
|
|
@ -23,7 +23,6 @@ package net.brutex.ai.dnn.networks;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
* Artificial Neural Network An artificial neural network (1) takes some input data, and (2)
|
||||||
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
* transforms this input data by calculating a weighted sum over the inputs and (3) applies a
|
||||||
|
|
|
@ -20,6 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
@ -30,11 +34,6 @@ import org.deeplearning4j.earlystopping.termination.IterationTerminationConditio
|
||||||
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
import org.deeplearning4j.exception.DL4JInvalidConfigException;
|
||||||
import org.nd4j.common.function.Supplier;
|
import org.nd4j.common.function.Supplier;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
public class EarlyStoppingConfiguration<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,16 +20,15 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
|
||||||
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
@JsonSubTypes(value = {@JsonSubTypes.Type(value = InMemoryModelSaver.class, name = "InMemoryModelSaver"),
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Data;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
public class EarlyStoppingResult<T extends IModel> implements Serializable {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
|
|
||||||
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
public class InMemoryModelSaver<T extends IModel> implements EarlyStoppingModelSaver<T> {
|
||||||
|
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
public class LocalFileGraphSaver implements EarlyStoppingModelSaver<ComputationGraph> {
|
||||||
|
|
||||||
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
private static final String BEST_GRAPH_BIN = "bestGraph.bin";
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.saver;
|
package org.deeplearning4j.earlystopping.saver;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.charset.Charset;
|
|
||||||
|
|
||||||
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
public class LocalFileModelSaver implements EarlyStoppingModelSaver<MultiLayerNetwork> {
|
||||||
|
|
||||||
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
private static final String BEST_MODEL_BIN = "bestModel.bin";
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class AutoencoderScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -29,7 +30,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
public class DataSetLossCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
|
@ -27,8 +29,6 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,12 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc;
|
package org.deeplearning4j.earlystopping.scorecalc;
|
||||||
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
|
|
|
@ -26,11 +26,11 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<IModel> {
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
package org.deeplearning4j.earlystopping.scorecalc.base;
|
package org.deeplearning4j.earlystopping.scorecalc.base;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
|
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
public class BestScoreEpochTerminationCondition implements EpochTerminationCondition {
|
||||||
|
|
|
@ -22,9 +22,7 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonSubTypes;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
public class MaxScoreIterationTerminationCondition implements IterationTerminationCondition {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
/**Terminate training based on max time.
|
/**Terminate training based on max time.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.termination;
|
package org.deeplearning4j.earlystopping.termination;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
||||||
|
@ -40,13 +46,6 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.FileNotFoundException;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
public abstract class BaseEarlyStoppingTrainer<T extends IModel> implements IEarlyStoppingTrainer<T> {
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class);
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.earlystopping.trainer;
|
package org.deeplearning4j.earlystopping.trainer;
|
||||||
|
|
||||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
|
||||||
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
||||||
|
|
|
@ -20,6 +20,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.MapperFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.module.SimpleModule;
|
||||||
|
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.common.primitives.AtomicBoolean;
|
import org.nd4j.common.primitives.AtomicBoolean;
|
||||||
|
@ -28,14 +35,6 @@ import org.nd4j.common.primitives.serde.JsonDeserializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonDeserializerAtomicDouble;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicBoolean;
|
||||||
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
import org.nd4j.common.primitives.serde.JsonSerializerAtomicDouble;
|
||||||
import com.fasterxml.jackson.annotation.JsonAutoDetect;
|
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.MapperFeature;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
|
||||||
import com.fasterxml.jackson.databind.module.SimpleModule;
|
|
||||||
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,15 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import com.google.common.collect.HashMultiset;
|
|
||||||
import com.google.common.collect.Multiset;
|
|
||||||
import lombok.Getter;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
public class ConfusionMatrix<T extends Comparable<? super T>> extends org.nd4j.evaluation.classification.ConfusionMatrix<T> {
|
||||||
|
|
|
@ -20,14 +20,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.evaluation.EvaluationAveraging;
|
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@Deprecated
|
@Deprecated
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Getter
|
@Getter
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval;
|
package org.deeplearning4j.eval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.evaluation.curves.BaseHistogram;
|
import org.nd4j.evaluation.curves.BaseHistogram;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,13 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.NonNull;
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
public class ReliabilityDiagram extends org.nd4j.evaluation.curves.ReliabilityDiagram {
|
||||||
|
|
|
@ -20,10 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.curves;
|
package org.deeplearning4j.eval.curves;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.eval.meta;
|
package org.deeplearning4j.eval.meta;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.adapters;
|
package org.deeplearning4j.nn.adapters;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
@ -32,8 +33,6 @@ import org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
public interface Classifier extends IModel {
|
public interface Classifier extends IModel {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface ITraininableLayerConfiguration {
|
public interface ITraininableLayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.io.Serializable;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -29,10 +29,8 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.LayerHelper;
|
import org.deeplearning4j.nn.layers.LayerHelper;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
* A layer is the highest-level building block in deep learning. A layer is a container that usually
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Param initializer for a layer
|
* Param initializer for a layer
|
||||||
*
|
*
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the model
|
* Update the model
|
||||||
|
|
|
@ -22,8 +22,8 @@ package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Classifier;
|
import org.deeplearning4j.nn.api.Classifier;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface IOutputLayer extends Layer, Classifier {
|
public interface IOutputLayer extends Layer, Classifier {
|
||||||
|
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface LayerConstraint extends Cloneable, Serializable {
|
public interface LayerConstraint extends Cloneable, Serializable {
|
||||||
|
|
|
@ -20,13 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.api.layers;
|
package org.deeplearning4j.nn.api.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import java.util.Map;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public interface RecurrentLayer extends Layer {
|
public interface RecurrentLayer extends Layer {
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.*;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
|
@ -43,16 +49,9 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@ -161,7 +160,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromJson(String json) {
|
public static ComputationGraphConfiguration fromJson(String json) {
|
||||||
//As per NeuralNetConfiguration.fromJson()
|
//As per NeuralNetConfiguration.fromJson()
|
||||||
ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
ComputationGraphConfiguration conf;
|
ComputationGraphConfiguration conf;
|
||||||
try {
|
try {
|
||||||
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
|
|
|
@ -19,10 +19,10 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
|
||||||
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
|
||||||
|
|
||||||
@JsonSerialize(using = DataFormatSerializer.class)
|
@JsonSerialize(using = DataFormatSerializer.class)
|
||||||
@JsonDeserialize(using = DataFormatDeserializer.class)
|
@JsonDeserialize(using = DataFormatDeserializer.class)
|
||||||
|
|
|
@ -21,14 +21,13 @@
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.Serializable;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface InputPreProcessor extends Serializable, Cloneable {
|
public interface InputPreProcessor extends Serializable, Cloneable {
|
||||||
|
|
|
@ -23,12 +23,9 @@ package org.deeplearning4j.nn.conf;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import java.util.*;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.jackson.Jacksonized;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
|
import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
@ -37,11 +34,8 @@ import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
import org.deeplearning4j.nn.conf.dropout.Dropout;
|
||||||
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
import org.deeplearning4j.nn.conf.dropout.IDropout;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
|
||||||
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
@ -50,7 +44,6 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||||
import org.deeplearning4j.util.NetworkUtils;
|
import org.deeplearning4j.util.NetworkUtils;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -60,9 +53,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
||||||
* multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers
|
* multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers
|
||||||
|
|
|
@ -326,7 +326,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1);
|
LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1);
|
||||||
// convolution 1d is an edge case where it has rnn input type but the filters
|
// convolution 1d is an edge case where it has rnn input type but the filters
|
||||||
// should be the output
|
// should be the output
|
||||||
if (layer instanceof Convolution1DLayer) {
|
if (layer instanceof Convolution1D) {
|
||||||
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
||||||
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
@ -27,11 +30,6 @@ import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -27,9 +29,6 @@ import org.nd4j.linalg.factory.Broadcast;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
public class MaxNormConstraint extends BaseConstraint {
|
public class MaxNormConstraint extends BaseConstraint {
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -27,11 +29,6 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Broadcast;
|
import org.nd4j.linalg.factory.Broadcast;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
|
|
@ -20,14 +20,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.constraint;
|
package org.deeplearning4j.nn.conf.constraint;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Broadcast;
|
import org.nd4j.linalg.factory.Broadcast;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
public class UnitNormConstraint extends BaseConstraint {
|
public class UnitNormConstraint extends BaseConstraint {
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,12 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.serde.LegacyDistributionHelper;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
|
||||||
public abstract class Distribution implements Serializable, Cloneable {
|
public abstract class Distribution implements Serializable, Cloneable {
|
||||||
|
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A log-normal distribution, with two parameters: mean and standard deviation.
|
* A log-normal distribution, with two parameters: mean and standard deviation.
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A normal (Gaussian) distribution, with two parameters: mean and standard deviation
|
* A normal (Gaussian) distribution, with two parameters: mean and standard deviation
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Orthogonal distribution, with gain parameter.<br>
|
* Orthogonal distribution, with gain parameter.<br>
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,12 +20,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
||||||
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
import org.apache.commons.math3.exception.util.LocalizedFormats;
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A uniform distribution, with two parameters: lower and upper - i.e., U(lower,upper)
|
* A uniform distribution, with two parameters: lower and upper - i.e., U(lower,upper)
|
||||||
|
|
|
@ -20,15 +20,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution.serde;
|
package org.deeplearning4j.nn.conf.distribution.serde;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
|
||||||
import com.fasterxml.jackson.core.JsonParseException;
|
import com.fasterxml.jackson.core.JsonParseException;
|
||||||
import com.fasterxml.jackson.core.JsonParser;
|
import com.fasterxml.jackson.core.JsonParser;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
||||||
import com.fasterxml.jackson.databind.DeserializationContext;
|
import com.fasterxml.jackson.databind.DeserializationContext;
|
||||||
import com.fasterxml.jackson.databind.JsonDeserializer;
|
import com.fasterxml.jackson.databind.JsonDeserializer;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
|
|
||||||
public class LegacyDistributionDeserializer extends JsonDeserializer<Distribution> {
|
public class LegacyDistributionDeserializer extends JsonDeserializer<Distribution> {
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution.serde;
|
package org.deeplearning4j.nn.conf.distribution.serde;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
|
|
||||||
@JsonDeserialize(using = LegacyDistributionDeserializer.class)
|
@JsonDeserialize(using = LegacyDistributionDeserializer.class)
|
||||||
public class LegacyDistributionHelper extends Distribution {
|
public class LegacyDistributionHelper extends Distribution {
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
@ -32,8 +34,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(exclude = {"lastPValue","alphaPrime","a","b", "mask"})
|
@EqualsAndHashCode(exclude = {"lastPValue","alphaPrime","a","b", "mask"})
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -36,8 +38,6 @@ import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||||
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
import org.nd4j.linalg.exception.ND4JOpProfilerException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
|
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
|
@ -30,8 +32,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties({"noise"})
|
@JsonIgnoreProperties({"noise"})
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -27,7 +28,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class GaussianNoise implements IDropout {
|
public class GaussianNoise implements IDropout {
|
||||||
|
|
|
@ -20,11 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.Serializable;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public interface IDropout extends Serializable, Cloneable {
|
public interface IDropout extends Serializable, Cloneable {
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.dropout;
|
package org.deeplearning4j.nn.conf.dropout;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -31,8 +33,6 @@ import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||||
import org.nd4j.linalg.factory.Broadcast;
|
import org.nd4j.linalg.factory.Broadcast;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.schedule.ISchedule;
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@JsonIgnoreProperties({"mask"})
|
@JsonIgnoreProperties({"mask"})
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -30,12 +31,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitUtil;
|
import org.deeplearning4j.nn.weights.WeightInitUtil;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -30,7 +31,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ElementWiseVertex extends GraphVertex {
|
public class ElementWiseVertex extends GraphVertex {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -28,7 +29,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,15 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.Serializable;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public abstract class GraphVertex implements Cloneable, Serializable {
|
public abstract class GraphVertex implements Cloneable, Serializable {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -30,7 +31,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
@ -34,8 +35,6 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
public class LayerVertex extends GraphVertex {
|
public class LayerVertex extends GraphVertex {
|
||||||
|
|
|
@ -22,7 +22,6 @@ package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import java.util.Arrays;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||||
|
@ -29,9 +31,6 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ReshapeVertex extends GraphVertex {
|
public class ReshapeVertex extends GraphVertex {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||||
|
@ -28,7 +29,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ScaleVertex extends GraphVertex {
|
public class ScaleVertex extends GraphVertex {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
@ -31,7 +32,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import java.util.Arrays;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -29,9 +31,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class SubsetVertex extends GraphVertex {
|
public class SubsetVertex extends GraphVertex {
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.conf.graph;
|
package org.deeplearning4j.nn.conf.graph;
|
||||||
|
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -30,7 +31,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
public class UnstackVertex extends GraphVertex {
|
public class UnstackVertex extends GraphVertex {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph.rnn;
|
package org.deeplearning4j.nn.conf.graph.rnn;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
|
@ -30,7 +31,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.graph.rnn;
|
package org.deeplearning4j.nn.conf.graph.rnn;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -29,7 +30,6 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class LastTimeStepVertex extends GraphVertex {
|
public class LastTimeStepVertex extends GraphVertex {
|
||||||
|
|
|
@ -20,25 +20,23 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.inputs;
|
package org.deeplearning4j.nn.conf.inputs;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.Arrays;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.DataFormat;
|
import org.deeplearning4j.nn.conf.DataFormat;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
import org.nd4j.common.util.OneTimeLogger;
|
import org.nd4j.common.util.OneTimeLogger;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue