Small refactor (#9047)

Signed-off-by: Dariusz Zbyrad <dariusz.zbyrad@gmail.com>
master
dariuszzbyrad 2020-07-25 13:29:56 +02:00 committed by GitHub
parent b3e3456b89
commit a17a579b1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 18 additions and 27 deletions

View File

@ -122,9 +122,8 @@ public class ContinuousParameterSpace implements ParameterSpace<Double> {
if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
return false;
if (this.index != other.index)
return false;
return true;
return this.index == other.index;
}
public int hashCode() {

View File

@ -132,9 +132,7 @@ public class IntegerParameterSpace implements ParameterSpace<Integer> {
if (distribution == null ? other.distribution != null
: !DistributionUtils.distributionEquals(distribution, other.distribution))
return false;
if (this.index != other.index)
return false;
return true;
return this.index == other.index;
}
public int hashCode() {

View File

@ -221,8 +221,7 @@ public class ClassPathResource {
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
}
InputStream stream = zipFile.getInputStream(entry);
return stream;
return zipFile.getInputStream(entry);
} catch (Exception e) {
throw new RuntimeException(e);
}

View File

@ -54,7 +54,7 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
boolean hasMutated = sut.mutate(genes);
Assert.assertFalse(hasMutated);
Assert.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes));
Assert.assertArrayEquals(new double[]{-1.0, -1.0, -1.0}, genes, 0.0);
}
@Test
@ -68,6 +68,6 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
boolean hasMutated = sut.mutate(genes);
Assert.assertTrue(hasMutated);
Assert.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes));
Assert.assertArrayEquals(new double[]{0.123, -1.0, -1.0}, genes, 0.0);
}
}

View File

@ -24,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterS
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
public class TestParameterSpaces extends BaseDL4JTest {
@ -94,10 +94,10 @@ public class TestParameterSpaces extends BaseDL4JTest {
ParameterSpace<Boolean> bSpace = new BooleanSpace();
bSpace.setIndices(1); //randomly setting to non zero
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0}));
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5}));
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7}));
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0}));
assertTrue(bSpace.getValue(new double[]{0.0, 0.0}));
assertTrue(bSpace.getValue(new double[]{0.1, 0.5}));
assertFalse(bSpace.getValue(new double[]{0.2, 0.7}));
assertFalse(bSpace.getValue(new double[]{0.3, 1.0}));
}
}

View File

@ -52,7 +52,7 @@ public class ROCScoreFunction extends BaseNetScoreFunction {
* AUC: Area under ROC curve<br>
* AUPRC: Area under precision/recall curve
*/
public enum Metric {AUC, AUPRC};
public enum Metric {AUC, AUPRC}
protected ROCType type;
protected Metric metric;

View File

@ -42,8 +42,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
public class TestLayerSpace extends BaseDL4JTest {
@ -256,7 +255,7 @@ public class TestLayerSpace extends BaseDL4JTest {
}
}
Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
assertTrue(!actual.hasBias());
assertFalse(actual.hasBias());
assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation()));
}

View File

@ -275,12 +275,10 @@ public class ArbiterCliGenerator {
}
private ComputationGraphSpace loadCompGraph() throws Exception {
ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
return multiLayerSpace;
return ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
}
private MultiLayerSpace loadMultiLayer() throws Exception {
MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
return multiLayerSpace;
return MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
}
}

View File

@ -85,7 +85,7 @@ public class ModelInfoPersistable extends BaseJavaPersistable {
private Integer modelIdx;
private Double score;
private CandidateStatus status;
private long lastUpdateTime;;
private long lastUpdateTime;
private long numParameters;
private int numLayers;
private int totalNumUpdates;

View File

@ -224,7 +224,7 @@ public class ArbiterStatusListener implements StatusListener {
throw new RuntimeException(e);
}
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
return new GlobalConfigPersistable.Builder()
.sessionId(sessionId)
.timestamp(System.currentTimeMillis())
.optimizationConfigJson(ocJson)
@ -232,7 +232,5 @@ public class ArbiterStatusListener implements StatusListener {
r.numCandidatesFailed(), r.numCandidatesTotal())
.optimizationRunner(r.getClass().getSimpleName())
.build();
return p;
}
}