Small refactor (#9047)

Signed-off-by: Dariusz Zbyrad <dariusz.zbyrad@gmail.com>
master
dariuszzbyrad 2020-07-25 13:29:56 +02:00 committed by GitHub
parent b3e3456b89
commit a17a579b1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 18 additions and 27 deletions

View File

@ -122,9 +122,8 @@ public class ContinuousParameterSpace implements ParameterSpace<Double> {
if (distribution == null ? other.distribution != null if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionsEqual(distribution, other.distribution)) : !DistributionUtils.distributionsEqual(distribution, other.distribution))
return false; return false;
if (this.index != other.index)
return false; return this.index == other.index;
return true;
} }
public int hashCode() { public int hashCode() {

View File

@ -132,9 +132,7 @@ public class IntegerParameterSpace implements ParameterSpace<Integer> {
if (distribution == null ? other.distribution != null if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionEquals(distribution, other.distribution)) : !DistributionUtils.distributionEquals(distribution, other.distribution))
return false; return false;
if (this.index != other.index) return this.index == other.index;
return false;
return true;
} }
public int hashCode() { public int hashCode() {

View File

@ -221,8 +221,7 @@ public class ClassPathResource {
throw new FileNotFoundException("Resource " + this.resourceName + " not found"); throw new FileNotFoundException("Resource " + this.resourceName + " not found");
} }
InputStream stream = zipFile.getInputStream(entry); return zipFile.getInputStream(entry);
return stream;
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -54,7 +54,7 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
boolean hasMutated = sut.mutate(genes); boolean hasMutated = sut.mutate(genes);
Assert.assertFalse(hasMutated); Assert.assertFalse(hasMutated);
Assert.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes)); Assert.assertArrayEquals(new double[]{-1.0, -1.0, -1.0}, genes, 0.0);
} }
@Test @Test
@ -68,6 +68,6 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
boolean hasMutated = sut.mutate(genes); boolean hasMutated = sut.mutate(genes);
Assert.assertTrue(hasMutated); Assert.assertTrue(hasMutated);
Assert.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes)); Assert.assertArrayEquals(new double[]{0.123, -1.0, -1.0}, genes, 0.0);
} }
} }

View File

@ -24,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterS
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
public class TestParameterSpaces extends BaseDL4JTest { public class TestParameterSpaces extends BaseDL4JTest {
@ -94,10 +94,10 @@ public class TestParameterSpaces extends BaseDL4JTest {
ParameterSpace<Boolean> bSpace = new BooleanSpace(); ParameterSpace<Boolean> bSpace = new BooleanSpace();
bSpace.setIndices(1); //randomly setting to non zero bSpace.setIndices(1); //randomly setting to non zero
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0})); assertTrue(bSpace.getValue(new double[]{0.0, 0.0}));
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5})); assertTrue(bSpace.getValue(new double[]{0.1, 0.5}));
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7})); assertFalse(bSpace.getValue(new double[]{0.2, 0.7}));
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0})); assertFalse(bSpace.getValue(new double[]{0.3, 1.0}));
} }
} }

View File

@ -52,7 +52,7 @@ public class ROCScoreFunction extends BaseNetScoreFunction {
* AUC: Area under ROC curve<br> * AUC: Area under ROC curve<br>
* AUPRC: Area under precision/recall curve * AUPRC: Area under precision/recall curve
*/ */
public enum Metric {AUC, AUPRC}; public enum Metric {AUC, AUPRC}
protected ROCType type; protected ROCType type;
protected Metric metric; protected Metric metric;

View File

@ -42,8 +42,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue;
public class TestLayerSpace extends BaseDL4JTest { public class TestLayerSpace extends BaseDL4JTest {
@ -256,7 +255,7 @@ public class TestLayerSpace extends BaseDL4JTest {
} }
} }
Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9}); Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
assertTrue(!actual.hasBias()); assertFalse(actual.hasBias());
assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation())); assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation()));
} }

View File

@ -275,12 +275,10 @@ public class ArbiterCliGenerator {
} }
private ComputationGraphSpace loadCompGraph() throws Exception { private ComputationGraphSpace loadCompGraph() throws Exception {
ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); return ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
return multiLayerSpace;
} }
private MultiLayerSpace loadMultiLayer() throws Exception { private MultiLayerSpace loadMultiLayer() throws Exception {
MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); return MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
return multiLayerSpace;
} }
} }

View File

@ -85,7 +85,7 @@ public class ModelInfoPersistable extends BaseJavaPersistable {
private Integer modelIdx; private Integer modelIdx;
private Double score; private Double score;
private CandidateStatus status; private CandidateStatus status;
private long lastUpdateTime;; private long lastUpdateTime;
private long numParameters; private long numParameters;
private int numLayers; private int numLayers;
private int totalNumUpdates; private int totalNumUpdates;

View File

@ -224,7 +224,7 @@ public class ArbiterStatusListener implements StatusListener {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder() return new GlobalConfigPersistable.Builder()
.sessionId(sessionId) .sessionId(sessionId)
.timestamp(System.currentTimeMillis()) .timestamp(System.currentTimeMillis())
.optimizationConfigJson(ocJson) .optimizationConfigJson(ocJson)
@ -232,7 +232,5 @@ public class ArbiterStatusListener implements StatusListener {
r.numCandidatesFailed(), r.numCandidatesTotal()) r.numCandidatesFailed(), r.numCandidatesTotal())
.optimizationRunner(r.getClass().getSimpleName()) .optimizationRunner(r.getClass().getSimpleName())
.build(); .build();
return p;
} }
} }