Fixing tests

Signed-off-by: brian <brian@brutex.de>
enhance-build-infrastructure
Brian Rosenberger 2023-05-08 09:34:44 +02:00
parent 581a14118c
commit ea504bff41
3 changed files with 5 additions and 3 deletions

1
.gitignore vendored
View File

@ -36,6 +36,7 @@ pom.xml.versionsBackup
pom.xml.next
release.properties
*dependency-reduced-pom.xml
*/build/*
# Specific for Nd4j
*.md5

View File

@ -334,6 +334,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet {
public void save(File to) {
try (FileOutputStream fos = new FileOutputStream(to, false);
BufferedOutputStream bos = new BufferedOutputStream(fos)) {
to.mkdirs();
save(bos);
} catch (IOException e) {
throw new RuntimeException(e);

View File

@ -166,10 +166,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
int seed = 123;
int listenerFreq = 1;
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
final LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed)
final var builder = NeuralNetConfiguration.builder().seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.layer(0, ConvolutionLayer.builder(5, 5).nIn(numChannels).nOut(6)
@ -181,7 +181,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
.build())
.inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels));
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
final MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
model.addTrainingListeners(new ScoreIterationListener(listenerFreq));