parent
b3e3456b89
commit
a17a579b1f
|
@ -122,9 +122,8 @@ public class ContinuousParameterSpace implements ParameterSpace<Double> {
|
||||||
if (distribution == null ? other.distribution != null
|
if (distribution == null ? other.distribution != null
|
||||||
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
|
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
|
||||||
return false;
|
return false;
|
||||||
if (this.index != other.index)
|
|
||||||
return false;
|
return this.index == other.index;
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
|
|
|
@ -132,9 +132,7 @@ public class IntegerParameterSpace implements ParameterSpace<Integer> {
|
||||||
if (distribution == null ? other.distribution != null
|
if (distribution == null ? other.distribution != null
|
||||||
: !DistributionUtils.distributionEquals(distribution, other.distribution))
|
: !DistributionUtils.distributionEquals(distribution, other.distribution))
|
||||||
return false;
|
return false;
|
||||||
if (this.index != other.index)
|
return this.index == other.index;
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
|
|
|
@ -221,8 +221,7 @@ public class ClassPathResource {
|
||||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
||||||
}
|
}
|
||||||
|
|
||||||
InputStream stream = zipFile.getInputStream(entry);
|
return zipFile.getInputStream(entry);
|
||||||
return stream;
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
|
||||||
boolean hasMutated = sut.mutate(genes);
|
boolean hasMutated = sut.mutate(genes);
|
||||||
|
|
||||||
Assert.assertFalse(hasMutated);
|
Assert.assertFalse(hasMutated);
|
||||||
Assert.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes));
|
Assert.assertArrayEquals(new double[]{-1.0, -1.0, -1.0}, genes, 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -68,6 +68,6 @@ public class RandomMutationOperatorTests extends BaseDL4JTest {
|
||||||
boolean hasMutated = sut.mutate(genes);
|
boolean hasMutated = sut.mutate(genes);
|
||||||
|
|
||||||
Assert.assertTrue(hasMutated);
|
Assert.assertTrue(hasMutated);
|
||||||
Assert.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes));
|
Assert.assertArrayEquals(new double[]{0.123, -1.0, -1.0}, genes, 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterS
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestParameterSpaces extends BaseDL4JTest {
|
public class TestParameterSpaces extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -94,10 +94,10 @@ public class TestParameterSpaces extends BaseDL4JTest {
|
||||||
ParameterSpace<Boolean> bSpace = new BooleanSpace();
|
ParameterSpace<Boolean> bSpace = new BooleanSpace();
|
||||||
bSpace.setIndices(1); //randomly setting to non zero
|
bSpace.setIndices(1); //randomly setting to non zero
|
||||||
|
|
||||||
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0}));
|
assertTrue(bSpace.getValue(new double[]{0.0, 0.0}));
|
||||||
assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5}));
|
assertTrue(bSpace.getValue(new double[]{0.1, 0.5}));
|
||||||
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7}));
|
assertFalse(bSpace.getValue(new double[]{0.2, 0.7}));
|
||||||
assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0}));
|
assertFalse(bSpace.getValue(new double[]{0.3, 1.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class ROCScoreFunction extends BaseNetScoreFunction {
|
||||||
* AUC: Area under ROC curve<br>
|
* AUC: Area under ROC curve<br>
|
||||||
* AUPRC: Area under precision/recall curve
|
* AUPRC: Area under precision/recall curve
|
||||||
*/
|
*/
|
||||||
public enum Metric {AUC, AUPRC};
|
public enum Metric {AUC, AUPRC}
|
||||||
|
|
||||||
protected ROCType type;
|
protected ROCType type;
|
||||||
protected Metric metric;
|
protected Metric metric;
|
||||||
|
|
|
@ -42,8 +42,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class TestLayerSpace extends BaseDL4JTest {
|
public class TestLayerSpace extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@ -256,7 +255,7 @@ public class TestLayerSpace extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
|
Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
|
||||||
assertTrue(!actual.hasBias());
|
assertFalse(actual.hasBias());
|
||||||
assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation()));
|
assertEquals(ArrayUtils.toString(new int[] {2,1} ),ArrayUtils.toString(actual.getDilation()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -275,12 +275,10 @@ public class ArbiterCliGenerator {
|
||||||
}
|
}
|
||||||
|
|
||||||
private ComputationGraphSpace loadCompGraph() throws Exception {
|
private ComputationGraphSpace loadCompGraph() throws Exception {
|
||||||
ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
return ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||||
return multiLayerSpace;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private MultiLayerSpace loadMultiLayer() throws Exception {
|
private MultiLayerSpace loadMultiLayer() throws Exception {
|
||||||
MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
return MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath)));
|
||||||
return multiLayerSpace;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,7 +85,7 @@ public class ModelInfoPersistable extends BaseJavaPersistable {
|
||||||
private Integer modelIdx;
|
private Integer modelIdx;
|
||||||
private Double score;
|
private Double score;
|
||||||
private CandidateStatus status;
|
private CandidateStatus status;
|
||||||
private long lastUpdateTime;;
|
private long lastUpdateTime;
|
||||||
private long numParameters;
|
private long numParameters;
|
||||||
private int numLayers;
|
private int numLayers;
|
||||||
private int totalNumUpdates;
|
private int totalNumUpdates;
|
||||||
|
|
|
@ -224,7 +224,7 @@ public class ArbiterStatusListener implements StatusListener {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
|
return new GlobalConfigPersistable.Builder()
|
||||||
.sessionId(sessionId)
|
.sessionId(sessionId)
|
||||||
.timestamp(System.currentTimeMillis())
|
.timestamp(System.currentTimeMillis())
|
||||||
.optimizationConfigJson(ocJson)
|
.optimizationConfigJson(ocJson)
|
||||||
|
@ -232,7 +232,5 @@ public class ArbiterStatusListener implements StatusListener {
|
||||||
r.numCandidatesFailed(), r.numCandidatesTotal())
|
r.numCandidatesFailed(), r.numCandidatesTotal())
|
||||||
.optimizationRunner(r.getClass().getSimpleName())
|
.optimizationRunner(r.getClass().getSimpleName())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
return p;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue