parent
b3e3456b89
commit
a17a579b1f
|
@ -122,9 +122,8 @@ public class ContinuousParameterSpace implements ParameterSpace<Double> {
|
|||
if (distribution == null ? other.distribution != null
|
||||
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
|
||||
return false;
|
||||
if (this.index != other.index)
|
||||
return false;
|
||||
return true;
|
||||
|
||||
return this.index == other.index;
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
|
|
|
@ -132,9 +132,7 @@ public class IntegerParameterSpace implements ParameterSpace<Integer> {
|
|||
if (distribution == null ? other.distribution != null
|
||||
: !DistributionUtils.distributionEquals(distribution, other.distribution))
|
||||
return false;
|
||||
if (this.index != other.index)
|
||||
return false;
|
||||
return true;
|
||||
return this.index == other.index;
|
||||
}
|
||||
|
||||
public int hashCode() {
|
||||
|
|
|
@ -221,8 +221,7 @@ public class ClassPathResource {
|
|||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||
}
|
||||
|
||||
InputStream stream = zipFile.getInputStream(entry);
|
||||
return stream;
|
||||
return zipFile.getInputStream(entry);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
|
|||
boolean hasMutated = sut.mutate(genes);
|
||||
|
||||
Assert.assertFalse(hasMutated);
|
||||
Assert.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes));
|
||||
Assert.assertArrayEquals(new double[]{-1.0, -1.0, -1.0}, genes, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -68,6 +68,6 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
|
|||
boolean hasMutated = sut.mutate(genes);
|
||||
|
||||
Assert.assertTrue(hasMutated);
|
||||
Assert.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes));
|
||||
Assert.assertArrayEquals(new double[]{0.123, -1.0, -1.0}, genes, 0.0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterS
|
|||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestParameterSpaces extends BaseDL4JTest {
|
||||
|
||||
|
@ -94,10 +94,10 @@ public class TestParameterSpaces extends BaseDL4JTest {
|
|||
ParameterSpace<Boolean> bSpace = new BooleanSpace();
|
||||
bSpace.setIndices(1); //randomly setting to non zero
|
||||
|
||||
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0}));
|
||||
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5}));
|
||||
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7}));
|
||||
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0}));
|
||||
assertTrue(bSpace.getValue(new double[]{0.0, 0.0}));
|
||||
assertTrue(bSpace.getValue(new double[]{0.1, 0.5}));
|
||||
assertFalse(bSpace.getValue(new double[]{0.2, 0.7}));
|
||||
assertFalse(bSpace.getValue(new double[]{0.3, 1.0}));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ public class ROCScoreFunction extends BaseNetScoreFunction {
|
|||
* AUC: Area under ROC curve<br>
|
||||
* AUPRC: Area under precision/recall curve
|
||||
*/
|
||||
public enum Metric {AUC, AUPRC};
|
||||
public enum Metric {AUC, AUPRC}
|
||||
|
||||
protected ROCType type;
|
||||
protected Metric metric;
|
||||
|
|
|
@ -42,8 +42,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TestLayerSpace extends BaseDL4JTest {
|
||||
|
||||
|
@ -256,7 +255,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
|
||||
assertTrue(!actual.hasBias());
|
||||
assertFalse(actual.hasBias());
|
||||
assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation()));
|
||||
}
|
||||
|
||||
|
|
|
@ -275,12 +275,10 @@ public class ArbiterCliGenerator {
|
|||
}
|
||||
|
||||
private ComputationGraphSpace loadCompGraph() throws Exception {
|
||||
ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||
return multiLayerSpace;
|
||||
return ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||
}
|
||||
|
||||
private MultiLayerSpace loadMultiLayer() throws Exception {
|
||||
MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||
return multiLayerSpace;
|
||||
return MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ public class ModelInfoPersistable extends BaseJavaPersistable {
|
|||
private Integer modelIdx;
|
||||
private Double score;
|
||||
private CandidateStatus status;
|
||||
private long lastUpdateTime;;
|
||||
private long lastUpdateTime;
|
||||
private long numParameters;
|
||||
private int numLayers;
|
||||
private int totalNumUpdates;
|
||||
|
|
|
@ -224,7 +224,7 @@ public class ArbiterStatusListener implements StatusListener {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
|
||||
return new GlobalConfigPersistable.Builder()
|
||||
.sessionId(sessionId)
|
||||
.timestamp(System.currentTimeMillis())
|
||||
.optimizationConfigJson(ocJson)
|
||||
|
@ -232,7 +232,5 @@ public class ArbiterStatusListener implements StatusListener {
|
|||
r.numCandidatesFailed(), r.numCandidatesTotal())
|
||||
.optimizationRunner(r.getClass().getSimpleName())
|
||||
.build();
|
||||
|
||||
return p;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue