From a17a579b1f81b156e59e3a53ea8bda32fae641f6 Mon Sep 17 00:00:00 2001 From: dariuszzbyrad Date: Sat, 25 Jul 2020 13:29:56 +0200 Subject: [PATCH] Small refactor (#9047) Signed-off-by: Dariusz Zbyrad --- .../parameter/continuous/ContinuousParameterSpace.java | 5 ++--- .../parameter/integer/IntegerParameterSpace.java | 4 +--- .../deeplearning4j/arbiter/util/ClassPathResource.java | 3 +-- .../genetic/mutation/RandomMutationOperatorTests.java | 4 ++-- .../optimize/parameter/TestParameterSpaces.java | 10 +++++----- .../arbiter/scoring/impl/ROCScoreFunction.java | 2 +- .../arbiter/multilayernetwork/TestLayerSpace.java | 5 ++--- .../arbiter/server/ArbiterCliGenerator.java | 6 ++---- .../arbiter/ui/data/ModelInfoPersistable.java | 2 +- .../arbiter/ui/listener/ArbiterStatusListener.java | 4 +--- 10 files changed, 18 insertions(+), 27 deletions(-) diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java index cf82fa701..d29c9dad7 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java @@ -122,9 +122,8 @@ public class ContinuousParameterSpace implements ParameterSpace { if (distribution == null ? other.distribution != null : !DistributionUtils.distributionsEqual(distribution, other.distribution)) return false; - if (this.index != other.index) - return false; - return true; + + return this.index == other.index; } public int hashCode() { diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java index 99035c3e4..2c6cd03ac 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java @@ -132,9 +132,7 @@ public class IntegerParameterSpace implements ParameterSpace { if (distribution == null ? other.distribution != null : !DistributionUtils.distributionEquals(distribution, other.distribution)) return false; - if (this.index != other.index) - return false; - return true; + return this.index == other.index; } public int hashCode() { diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java index 3982f1126..9ec8c5f7e 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java @@ -221,8 +221,7 @@ public class ClassPathResource { throw new FileNotFoundException("Resource " + this.resourceName + " not found"); } - InputStream stream = zipFile.getInputStream(entry); - return stream; + return zipFile.getInputStream(entry); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java index 38e2ba87b..5e4c60983 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java @@ -54,7 +54,7 @@ public class RandomMutationOperatorTests extends BaseDL4JTest { boolean hasMutated = sut.mutate(genes); Assert.assertFalse(hasMutated); - Assert.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes)); + Assert.assertArrayEquals(new double[]{-1.0, -1.0, -1.0}, genes, 0.0); } @Test @@ -68,6 +68,6 @@ public class RandomMutationOperatorTests extends BaseDL4JTest { boolean hasMutated = sut.mutate(genes); Assert.assertTrue(hasMutated); - Assert.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes)); + Assert.assertArrayEquals(new double[]{0.123, -1.0, -1.0}, genes, 0.0); } } diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java index 98396a941..93d8fc0d5 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java @@ -24,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterS import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; import org.junit.Test; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; public class TestParameterSpaces extends BaseDL4JTest { @@ -94,10 +94,10 @@ public class TestParameterSpaces extends BaseDL4JTest { ParameterSpace bSpace = new BooleanSpace(); bSpace.setIndices(1); //randomly setting to non zero - assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0})); - assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5})); - assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7})); - assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0})); + assertTrue(bSpace.getValue(new double[]{0.0, 0.0})); + assertTrue(bSpace.getValue(new double[]{0.1, 0.5})); + assertFalse(bSpace.getValue(new double[]{0.2, 0.7})); + assertFalse(bSpace.getValue(new double[]{0.3, 1.0})); } } diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java index 9203963e3..379cb9632 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java @@ -52,7 +52,7 @@ public class ROCScoreFunction extends BaseNetScoreFunction { * AUC: Area under ROC curve
* AUPRC: Area under precision/recall curve */ - public enum Metric {AUC, AUPRC}; + public enum Metric {AUC, AUPRC} protected ROCType type; protected Metric metric; diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java index 959cafc35..06a062522 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java @@ -42,8 +42,7 @@ import java.util.Collections; import java.util.List; import java.util.Random; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class TestLayerSpace extends BaseDL4JTest { @@ -256,7 +255,7 @@ public class TestLayerSpace extends BaseDL4JTest { } } Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9}); - assertTrue(!actual.hasBias()); + assertFalse(actual.hasBias()); assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation())); } diff --git a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java index af19a81f7..4ed0a9623 100644 --- a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java +++ b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java @@ -275,12 +275,10 @@ public class ArbiterCliGenerator { } private ComputationGraphSpace loadCompGraph() throws Exception { - ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); - return multiLayerSpace; + return ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); } private MultiLayerSpace loadMultiLayer() throws Exception { - MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); - return multiLayerSpace; + return MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); } } diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java index 4d1ee4e5f..d7f6b0ba1 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java @@ -85,7 +85,7 @@ public class ModelInfoPersistable extends BaseJavaPersistable { private Integer modelIdx; private Double score; private CandidateStatus status; - private long lastUpdateTime;; + private long lastUpdateTime; private long numParameters; private int numLayers; private int totalNumUpdates; diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java index c14258be2..51fd39a33 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java @@ -224,7 +224,7 @@ public class ArbiterStatusListener implements StatusListener { throw new RuntimeException(e); } - GlobalConfigPersistable p = new GlobalConfigPersistable.Builder() + return new GlobalConfigPersistable.Builder() .sessionId(sessionId) .timestamp(System.currentTimeMillis()) .optimizationConfigJson(ocJson) @@ -232,7 +232,5 @@ public class ArbiterStatusListener implements StatusListener { r.numCandidatesFailed(), r.numCandidatesTotal()) .optimizationRunner(r.getClass().getSimpleName()) .build(); - - return p; } }