parent 98ea7d0b3b
commit 4482113f23
@@ -1,68 +0,0 @@
-/*
- *
- * ******************************************************************************
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- *
- */
-
-package net.brutex.ai.dnn.api;
-
-public class Animal {
-
-    private String animalString;
-
-    protected Animal(AnimalBuilder<?, ?> b) {
-        this.animalString = b.animalString;
-    }
-
-    public static AnimalBuilder<?, ?> builder() {
-        return new AnimalBuilderImpl();
-    }
-
-    public static abstract class AnimalBuilder<C extends Animal, B extends AnimalBuilder<C, B>> {
-
-        private String animalString;
-
-        public B animalString(String animalString) {
-            this.animalString = animalString;
-            return self();
-        }
-
-        protected abstract B self();
-
-        public abstract C build();
-
-        public String toString() {
-            return "Animal.AnimalBuilder(animalString=" + this.animalString + ")";
-        }
-    }
-
-    private static final class AnimalBuilderImpl extends
-            AnimalBuilder<Animal, AnimalBuilderImpl> {
-
-        private AnimalBuilderImpl() {
-        }
-
-        protected AnimalBuilderImpl self() {
-            return this;
-        }
-
-        public Animal build() {
-            return new Animal(this);
-        }
-    }
-}
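Note: the deleted Animal class above is a hand-written self-typed (recursive generics) builder, essentially the shape Lombok's @SuperBuilder generates. For context on the pattern being removed, a hypothetical subclass such as the Dog referenced by the removed NN.test() in the next hunk could extend it as sketched below; Dog, DogBuilder and dogString(...) are assumptions here, since their real sources are not part of this commit.

```java
package net.brutex.ai.dnn.api;

// Hypothetical sketch only: Dog and dogString(...) appear solely in the removed NN.test(),
// their actual source is not in this diff.
public class Dog extends Animal {

    private final String dogString;

    protected Dog(DogBuilder<?, ?> b) {
        super(b);                      // the AnimalBuilder part initialises animalString
        this.dogString = b.dogString;
    }

    public static DogBuilder<?, ?> builder() {
        return new DogBuilderImpl();
    }

    public static abstract class DogBuilder<C extends Dog, B extends DogBuilder<C, B>>
            extends Animal.AnimalBuilder<C, B> {

        private String dogString;

        public B dogString(String dogString) {
            this.dogString = dogString;
            return self();             // returns the concrete builder type, so chaining keeps Dog methods
        }
    }

    private static final class DogBuilderImpl extends DogBuilder<Dog, DogBuilderImpl> {

        protected DogBuilderImpl self() {
            return this;
        }

        public Dog build() {
            return new Dog(this);
        }
    }
}
```

The self() indirection is what lets a call chain like Dog.builder().animalString("").dogString("") cross both levels of the hierarchy without casts.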
@@ -35,9 +35,5 @@ public class NN {
         return NeuralNetConfiguration.builder();
     }
 
-    void test() {
-        Dog.DogBuilder builder = Dog.builder()
-            .animalString("")
-            .dogString("");
-    }
 }
@@ -29,9 +29,11 @@ import org.apache.commons.lang3.RandomUtils;
 import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.layers.ActivationLayer;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
@@ -48,9 +50,9 @@ class dnnTest {
 
     @Test
     void testFFLayer() {
-        int numFeatures = 6;
-        int batchSize = 5;
-        int numRows = 100;
+        int numFeatures = 600;
+        int batchSize = 5000;
+        int numRows = 10000;
         AtomicInteger cnt = new AtomicInteger(0);
         FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
 
@@ -85,8 +87,9 @@ class dnnTest {
             .gradientNormalizationThreshold(100)
             .weightInit(new WeightInitXavier())
             .activation(new ActivationSigmoid())
+            .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.7)))
             // .inputType(InputType.convolutional(28, 28, 1))
-            .layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
+            .layer(new DenseLayer.Builder().nIn(numFeatures).nOut(20).build())
             .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
             .layer(new DenseLayer.Builder().nIn(20).nOut(40).build())
             .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
@@ -94,14 +97,18 @@ class dnnTest {
             .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
             .layer(new DenseLayer.Builder().nIn(12).nOut(8).build())
             .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
-            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nOut(6).build())
+            .layer(
+                new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS)
+                    .activation(Activation.SIGMOID)
+                    .nOut(numFeatures)
+                    .build())
             .build();
 
         MultiLayerNetwork net = new MultiLayerNetwork(network);
         net.addTrainingListeners(new ScoreToChartListener("dnnTest"));
         FloatsDataSetIterator dset = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
 
-        for (int i = 0; i < 2000000; i++) {
+        for (int i = 0; i < 20000000; i++) {
             net.fit(dset);
             System.out.println("Score: " + net.getScore());
         }
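Note: both FloatsDataSetIterator instances are fed by a floatIterable(numRows, numFeatures) helper that is not shown in this commit. Below is a minimal sketch of what such a helper could look like, assuming it yields feature/label float[] pairs whose labels mirror the features (a reconstruction-style target, consistent with the output layer being widened to nOut(numFeatures)); the package of Pair and the helper's exact shape are assumptions and may differ from the real dnnTest.

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.common.primitives.Pair; // assumption: Pair's package differs across ND4J versions

class FloatIterableSketch {

    // Hypothetical stand-in for dnnTest.floatIterable(numRows, numFeatures): numRows
    // feature/label pairs of length numFeatures, in the Iterable<Pair<float[], float[]>>
    // form that FloatsDataSetIterator consumes.
    static Iterable<Pair<float[], float[]>> floatIterable(int numRows, int numFeatures) {
        List<Pair<float[], float[]>> rows = new ArrayList<>(numRows);
        for (int i = 0; i < numRows; i++) {
            float[] features = new float[numFeatures];
            for (int j = 0; j < numFeatures; j++) {
                features[j] = RandomUtils.nextFloat(0, 1);
            }
            // Labels mirror the features (assumed), matching SIGMOID output + SQUARED_LOSS.
            rows.add(new Pair<>(features, features.clone()));
        }
        return rows;
    }
}
```

With numRows = 10000 and batchSize = 5000, each pass over such an iterator yields two mini-batches.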