removing demo files

Signed-off-by: brian <brian@brutex.de>
master
agibsonccc 2023-04-15 04:26:29 +02:00 committed by brian
parent 98ea7d0b3b
commit 4482113f23
3 changed files with 14 additions and 79 deletions

View File

@ -1,68 +0,0 @@
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.ai.dnn.api;
public class Animal {
private String animalString;
protected Animal(AnimalBuilder<?, ?> b) {
this.animalString = b.animalString;
}
public static AnimalBuilder<?, ?> builder() {
return new AnimalBuilderImpl();
}
public static abstract class AnimalBuilder<C extends Animal, B extends AnimalBuilder<C, B>> {
private String animalString;
public B animalString(String animalString) {
this.animalString = animalString;
return self();
}
protected abstract B self();
public abstract C build();
public String toString() {
return "Animal.AnimalBuilder(animalString=" + this.animalString + ")";
}
}
private static final class AnimalBuilderImpl extends
AnimalBuilder<Animal, AnimalBuilderImpl> {
private AnimalBuilderImpl() {
}
protected AnimalBuilderImpl self() {
return this;
}
public Animal build() {
return new Animal(this);
}
}
}

View File

@ -35,9 +35,5 @@ public class NN {
return NeuralNetConfiguration.builder(); return NeuralNetConfiguration.builder();
} }
void test() {
Dog.DogBuilder builder = Dog.builder()
.animalString("")
.dogString("");
}
} }

View File

@ -29,9 +29,11 @@ import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator; import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
@ -48,9 +50,9 @@ class dnnTest {
@Test @Test
void testFFLayer() { void testFFLayer() {
int numFeatures = 6; int numFeatures = 600;
int batchSize = 5; int batchSize = 5000;
int numRows = 100; int numRows = 10000;
AtomicInteger cnt = new AtomicInteger(0); AtomicInteger cnt = new AtomicInteger(0);
FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
@ -85,8 +87,9 @@ class dnnTest {
.gradientNormalizationThreshold(100) .gradientNormalizationThreshold(100)
.weightInit(new WeightInitXavier()) .weightInit(new WeightInitXavier())
.activation(new ActivationSigmoid()) .activation(new ActivationSigmoid())
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.7)))
// .inputType(InputType.convolutional(28, 28, 1)) // .inputType(InputType.convolutional(28, 28, 1))
.layer(new DenseLayer.Builder().nIn(6).nOut(20).build()) .layer(new DenseLayer.Builder().nIn(numFeatures).nOut(20).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DenseLayer.Builder().nIn(20).nOut(40).build()) .layer(new DenseLayer.Builder().nIn(20).nOut(40).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
@ -94,14 +97,18 @@ class dnnTest {
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new DenseLayer.Builder().nIn(12).nOut(8).build()) .layer(new DenseLayer.Builder().nIn(12).nOut(8).build())
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nOut(6).build()) .layer(
new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS)
.activation(Activation.SIGMOID)
.nOut(numFeatures)
.build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(network); MultiLayerNetwork net = new MultiLayerNetwork(network);
net.addTrainingListeners(new ScoreToChartListener("dnnTest")); net.addTrainingListeners(new ScoreToChartListener("dnnTest"));
FloatsDataSetIterator dset = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); FloatsDataSetIterator dset = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
for (int i = 0; i < 2000000; i++) { for (int i = 0; i < 20000000; i++) {
net.fit(dset); net.fit(dset);
System.out.println("Score: " + net.getScore()); System.out.println("Score: " + net.getScore());
} }