performance improvement (#9055)

* performance improvement

Signed-off-by: Dariusz Zbyrad <dariusz.zbyrad@gmail.com>

* revert some changes

Signed-off-by: Dariusz Zbyrad <dariusz.zbyrad@gmail.com>
master
dariuszzbyrad 2020-07-29 13:33:46 +02:00 committed by GitHub
parent 6eb3c9260e
commit 0a865f6ee3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 20 additions and 69 deletions

View File

@ -16,10 +16,7 @@
package org.deeplearning4j.arbiter.optimize.api; package org.deeplearning4j.arbiter.optimize.api;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
/** /**

View File

@ -29,7 +29,6 @@ import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Properties; import java.util.Properties;

View File

@ -20,6 +20,8 @@ import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException; import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException; import org.apache.commons.math3.exception.OutOfRangeException;
import java.util.Arrays;
/** /**
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value * Degenerate distribution: i.e., integer "distribution" that is just a fixed value
*/ */
@ -89,8 +91,7 @@ public class DegenerateIntegerDistribution implements IntegerDistribution {
@Override @Override
public int[] sample(int sampleSize) { public int[] sample(int sampleSize) {
int[] out = new int[sampleSize]; int[] out = new int[sampleSize];
for (int i = 0; i < out.length; i++) Arrays.fill(out, value);
out[i] = value;
return out; return out;
} }
} }

View File

@ -22,7 +22,6 @@ import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils; import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer; import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer; import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.arbiter.optimize.parameter.discrete;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

View File

@ -319,9 +319,7 @@ public abstract class BaseOptimizationRunner implements IOptimizationRunner {
@Override @Override
public void removeListeners(StatusListener... listeners) { public void removeListeners(StatusListener... listeners) {
for (StatusListener l : listeners) { for (StatusListener l : listeners) {
if (statusListeners.contains(l)) { statusListeners.remove(l);
statusListeners.remove(l);
}
} }
} }
@ -332,8 +330,7 @@ public abstract class BaseOptimizationRunner implements IOptimizationRunner {
@Override @Override
public List<CandidateInfo> getCandidateStatus() { public List<CandidateInfo> getCandidateStatus() {
List<CandidateInfo> list = new ArrayList<>(); List<CandidateInfo> list = new ArrayList<>(currentStatus.values());
list.addAll(currentStatus.values());
return list; return list;
} }

View File

@ -23,7 +23,6 @@ import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.Arrays;
public class RandomMutationOperatorTests extends BaseDL4JTest { public class RandomMutationOperatorTests extends BaseDL4JTest {
@Test @Test

View File

@ -187,19 +187,19 @@ public abstract class BaseNetworkSpace<T> extends AbstractParameterSpace<T> {
if (allParamConstraints != null){ if (allParamConstraints != null){
List<LayerConstraint> c = allParamConstraints.getValue(values); List<LayerConstraint> c = allParamConstraints.getValue(values);
if(c != null){ if(c != null){
builder.constrainAllParameters(c.toArray(new LayerConstraint[c.size()])); builder.constrainAllParameters(c.toArray(new LayerConstraint[0]));
} }
} }
if (weightConstraints != null){ if (weightConstraints != null){
List<LayerConstraint> c = weightConstraints.getValue(values); List<LayerConstraint> c = weightConstraints.getValue(values);
if(c != null){ if(c != null){
builder.constrainWeights(c.toArray(new LayerConstraint[c.size()])); builder.constrainWeights(c.toArray(new LayerConstraint[0]));
} }
} }
if (biasConstraints != null){ if (biasConstraints != null){
List<LayerConstraint> c = biasConstraints.getValue(values); List<LayerConstraint> c = biasConstraints.getValue(values);
if(c != null){ if(c != null){
builder.constrainBias(c.toArray(new LayerConstraint[c.size()])); builder.constrainBias(c.toArray(new LayerConstraint[0]));
} }
} }
@ -226,7 +226,7 @@ public abstract class BaseNetworkSpace<T> extends AbstractParameterSpace<T> {
out.add(next); out.add(next);
} else { } else {
Map<String, ParameterSpace> m = next.getNestedSpaces(); Map<String, ParameterSpace> m = next.getNestedSpaces();
ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]); ParameterSpace[] arr = m.values().toArray(new ParameterSpace[0]);
for (int i = arr.length - 1; i >= 0; i--) { for (int i = arr.length - 1; i >= 0; i--) {
stack.add(arr[i]); stack.add(arr[i]);
} }

View File

@ -18,18 +18,12 @@ package org.deeplearning4j.arbiter.conf.updater;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
public class AdaGradSpace extends BaseUpdaterSpace { public class AdaGradSpace extends BaseUpdaterSpace {

View File

@ -17,8 +17,6 @@
package org.deeplearning4j.arbiter.conf.updater; package org.deeplearning4j.arbiter.conf.updater;
import lombok.Data; import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;

View File

@ -21,7 +21,6 @@ import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.InverseSchedule; import org.nd4j.linalg.schedule.InverseSchedule;
import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.ScheduleType;

View File

@ -22,7 +22,6 @@ import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.InverseSchedule;
import org.nd4j.linalg.schedule.PolySchedule; import org.nd4j.linalg.schedule.PolySchedule;
import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -22,7 +22,6 @@ import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.PolySchedule;
import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.SigmoidSchedule; import org.nd4j.linalg.schedule.SigmoidSchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -22,7 +22,6 @@ import lombok.NonNull;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.InverseSchedule;
import org.nd4j.linalg.schedule.ScheduleType; import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.StepSchedule; import org.nd4j.linalg.schedule.StepSchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.arbiter.dropout;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.GaussianDropout; import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.dropout.IDropout;

View File

@ -25,7 +25,6 @@ import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.util.LeafUtils; import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.nn.conf.layers.AbstractLSTM; import org.deeplearning4j.nn.conf.layers.AbstractLSTM;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;

View File

@ -86,13 +86,13 @@ public class BatchNormalizationSpace extends FeedForwardLayerSpace<BatchNormaliz
if (constrainBeta != null){ if (constrainBeta != null){
List<LayerConstraint> c = constrainBeta.getValue(values); List<LayerConstraint> c = constrainBeta.getValue(values);
if(c != null){ if(c != null){
builder.constrainBeta(c.toArray(new LayerConstraint[c.size()])); builder.constrainBeta(c.toArray(new LayerConstraint[0]));
} }
} }
if (constrainGamma != null){ if (constrainGamma != null){
List<LayerConstraint> c = constrainGamma.getValue(values); List<LayerConstraint> c = constrainGamma.getValue(values);
if(c != null){ if(c != null){
builder.constrainGamma(c.toArray(new LayerConstraint[c.size()])); builder.constrainGamma(c.toArray(new LayerConstraint[0]));
} }
} }
} }

View File

@ -18,13 +18,11 @@ package org.deeplearning4j.arbiter.layers;
import lombok.*; import lombok.*;
import org.deeplearning4j.arbiter.dropout.DropoutSpace; import org.deeplearning4j.arbiter.dropout.DropoutSpace;
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.layers.DropoutLayer; import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import java.util.Collections;
import java.util.List; import java.util.List;
@Data @Data

View File

@ -56,19 +56,19 @@ public abstract class FeedForwardLayerSpace<L extends FeedForwardLayer> extends
if (constrainWeights != null){ if (constrainWeights != null){
List<LayerConstraint> c = constrainWeights.getValue(values); List<LayerConstraint> c = constrainWeights.getValue(values);
if(c != null){ if(c != null){
builder.constrainWeights(c.toArray(new LayerConstraint[c.size()])); builder.constrainWeights(c.toArray(new LayerConstraint[0]));
} }
} }
if (constrainBias != null){ if (constrainBias != null){
List<LayerConstraint> c = constrainBias.getValue(values); List<LayerConstraint> c = constrainBias.getValue(values);
if(c != null){ if(c != null){
builder.constrainBias(c.toArray(new LayerConstraint[c.size()])); builder.constrainBias(c.toArray(new LayerConstraint[0]));
} }
} }
if (constrainAll != null){ if (constrainAll != null){
List<LayerConstraint> c = constrainAll.getValue(values); List<LayerConstraint> c = constrainAll.getValue(values);
if(c != null){ if(c != null){
builder.constrainAllParameters(c.toArray(new LayerConstraint[c.size()])); builder.constrainAllParameters(c.toArray(new LayerConstraint[0]));
} }
} }

View File

@ -20,8 +20,6 @@ import lombok.AccessLevel;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.util.LeafUtils; import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
@ -60,9 +58,7 @@ public class GravesLSTMLayerSpace extends AbstractLSTMLayerSpace<GravesLSTM> {
@Override @Override
public String toString(String delim) { public String toString(String delim) {
StringBuilder sb = new StringBuilder("GravesLSTMLayerSpace("); return "GravesLSTMLayerSpace(" + super.toString(delim) + ")";
sb.append(super.toString(delim)).append(")");
return sb.toString();
} }
public static class Builder extends AbstractLSTMLayerSpace.Builder<Builder> { public static class Builder extends AbstractLSTMLayerSpace.Builder<Builder> {

View File

@ -20,10 +20,7 @@ import lombok.AccessLevel;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.util.LeafUtils; import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
/** /**
@ -61,9 +58,7 @@ public class LSTMLayerSpace extends AbstractLSTMLayerSpace<LSTM> {
@Override @Override
public String toString(String delim) { public String toString(String delim) {
StringBuilder sb = new StringBuilder("LSTMLayerSpace("); return "LSTMLayerSpace(" + super.toString(delim) + ")";
sb.append(super.toString(delim)).append(")");
return sb.toString();
} }
public static class Builder extends AbstractLSTMLayerSpace.Builder<Builder> { public static class Builder extends AbstractLSTMLayerSpace.Builder<Builder> {

View File

@ -64,7 +64,7 @@ public abstract class LayerSpace<L extends Layer> extends AbstractParameterSpace
out.add(next); out.add(next);
} else { } else {
Map<String, ParameterSpace> m = next.getNestedSpaces(); Map<String, ParameterSpace> m = next.getNestedSpaces();
ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]); ParameterSpace[] arr = m.values().toArray(new ParameterSpace[0]);
for (int i = arr.length - 1; i >= 0; i--) { for (int i = arr.length - 1; i >= 0; i--) {
stack.add(arr[i]); stack.add(arr[i]);
} }

View File

@ -66,7 +66,7 @@ public class SeparableConvolution2DLayerSpace extends BaseConvolutionLayerSpace<
if (pointWiseConstraints != null){ if (pointWiseConstraints != null){
List<LayerConstraint> c = pointWiseConstraints.getValue(values); List<LayerConstraint> c = pointWiseConstraints.getValue(values);
if(c != null){ if(c != null){
builder.constrainPointWise(c.toArray(new LayerConstraint[c.size()])); builder.constrainPointWise(c.toArray(new LayerConstraint[0]));
} }
} }
} }

View File

@ -21,7 +21,6 @@ import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.api.IterationListener;
import java.util.List; import java.util.List;

View File

@ -94,7 +94,7 @@ public class FileModelSaver implements ResultSaver {
Object additionalResults = result.getModelSpecificResults(); Object additionalResults = result.getModelSpecificResults();
if (additionalResults != null && additionalResults instanceof Serializable) { if (additionalResults instanceof Serializable) {
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(additionalResultsFile))) { try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(additionalResultsFile))) {
oos.writeObject(additionalResults); oos.writeObject(additionalResults);
} }

View File

@ -26,7 +26,6 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;

View File

@ -69,7 +69,6 @@ import java.util.Map;
import java.util.Properties; import java.util.Properties;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j

View File

@ -24,7 +24,6 @@ import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace; import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace; import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace; import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.multilayernetwork.MnistDataSetIteratorFactory;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
@ -47,12 +46,8 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ClassificationScoreCalculator;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG; import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG;
import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.IEvaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;

View File

@ -72,7 +72,6 @@ import java.util.Map;
import java.util.Properties; import java.util.Properties;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j

View File

@ -384,8 +384,7 @@ public class TestMultiLayerSpace extends BaseDL4JTest {
double[] ones = new double[numParams]; double[] ones = new double[numParams];
for (int i = 0; i < ones.length; i++) Arrays.fill(ones, 1.0);
ones[i] = 1.0;
configuration = mls.getValue(ones); configuration = mls.getValue(ones);

View File

@ -17,14 +17,11 @@
package org.deeplearning4j.arbiter.util; package org.deeplearning4j.arbiter.util;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import java.util.Map;
@AllArgsConstructor @AllArgsConstructor
public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { public class TestDataFactoryProviderMnist implements DataSetIteratorFactory {

View File

@ -38,7 +38,6 @@ import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File; import java.io.File;