FIX: forgotten imports (#9113)

Signed-off-by: hosuaby <alexei.klenin@gmail.com>
master
Alexei KLENIN 2020-10-24 16:01:09 -07:00 committed by GitHub
parent 881a672fa1
commit ca4aee16ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 12 deletions

View File

@ -31,6 +31,11 @@ import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
@ -43,7 +48,11 @@ import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
* This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)

View File

@ -16,11 +16,21 @@
package org.deeplearning4j.arbiter.multilayernetwork;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.arbiter.TestUtils;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.*;
import org.deeplearning4j.arbiter.layers.ActivationLayerSpace;
import org.deeplearning4j.arbiter.layers.BatchNormalizationSpace;
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
import org.deeplearning4j.arbiter.layers.Deconvolution2DLayerSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.EmbeddingLayerSpace;
import org.deeplearning4j.arbiter.layers.GravesBidirectionalLSTMLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
@ -31,6 +41,14 @@ import org.deeplearning4j.nn.conf.constraint.MaxNormConstraint;
import org.deeplearning4j.nn.conf.constraint.MinMaxNormConstraint;
import org.deeplearning4j.nn.conf.constraint.NonNegativeConstraint;
import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.Convolution2D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
@ -41,13 +59,10 @@ import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
public class TestLayerSpace extends BaseDL4JTest {
@Test
public void testBasic1() {
DenseLayer expected = new DenseLayer.Builder().nOut(13).activation(Activation.RELU).build();
DenseLayerSpace space = new DenseLayerSpace.Builder().nOut(13).activation(Activation.RELU).build();
@ -188,8 +203,6 @@ public class TestLayerSpace extends BaseDL4JTest {
ActivationLayer al = als.getValue(d);
IActivation activation = al.getActivationFn();
// System.out.println(activation);
assertTrue(containsActivationFunction(actFns, activation));
}
}
@ -226,8 +239,6 @@ public class TestLayerSpace extends BaseDL4JTest {
IActivation activation = el.getActivationFn();
long nOut = el.getNOut();
// System.out.println(activation + "\t" + nOut);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);
}
@ -293,8 +304,6 @@ public class TestLayerSpace extends BaseDL4JTest {
long nOut = el.getNOut();
double forgetGate = el.getForgetGateBiasInit();
// System.out.println(activation + "\t" + nOut + "\t" + forgetGate);
assertTrue(containsActivationFunction(actFns, activation));
assertTrue(nOut >= 10 && nOut <= 20);
assertTrue(forgetGate >= 0.5 && forgetGate <= 0.8);

View File

@ -56,6 +56,16 @@ import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
@ -809,7 +819,6 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
Candidate<DL4JConfiguration> conf = candidateGenerator.getCandidate();
MultiLayerConfiguration mlc = conf.getValue().getMultiLayerConfiguration();
FeedForwardLayer ffl = ((FeedForwardLayer) mlc.getConf(0).getLayer());
// System.out.println(ffl.getIUpdater() + ", " + ffl.getNOut());
actCandidates.add(new Pair<>(ffl.getIUpdater().getLearningRate(0,0), (int)ffl.getNOut()));
}

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/**