parent
98ea7d0b3b
commit
4482113f23
|
@ -1,68 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*
|
||||
*/
|
||||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
public class Animal {
|
||||
|
||||
private String animalString;
|
||||
|
||||
protected Animal(AnimalBuilder<?, ?> b) {
|
||||
this.animalString = b.animalString;
|
||||
}
|
||||
|
||||
public static AnimalBuilder<?, ?> builder() {
|
||||
return new AnimalBuilderImpl();
|
||||
}
|
||||
|
||||
public static abstract class AnimalBuilder<C extends Animal, B extends AnimalBuilder<C, B>> {
|
||||
|
||||
private String animalString;
|
||||
|
||||
public B animalString(String animalString) {
|
||||
this.animalString = animalString;
|
||||
return self();
|
||||
}
|
||||
|
||||
protected abstract B self();
|
||||
|
||||
public abstract C build();
|
||||
|
||||
public String toString() {
|
||||
return "Animal.AnimalBuilder(animalString=" + this.animalString + ")";
|
||||
}
|
||||
}
|
||||
|
||||
private static final class AnimalBuilderImpl extends
|
||||
AnimalBuilder<Animal, AnimalBuilderImpl> {
|
||||
|
||||
private AnimalBuilderImpl() {
|
||||
}
|
||||
|
||||
protected AnimalBuilderImpl self() {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Animal build() {
|
||||
return new Animal(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -35,9 +35,5 @@ public class NN {
|
|||
return NeuralNetConfiguration.builder();
|
||||
}
|
||||
|
||||
void test() {
|
||||
Dog.DogBuilder builder = Dog.builder()
|
||||
.animalString("")
|
||||
.dogString("");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -29,9 +29,11 @@ import org.apache.commons.lang3.RandomUtils;
|
|||
import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.weightnoise.WeightNoise;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreToChartListener;
|
||||
|
@ -48,9 +50,9 @@ class dnnTest {
|
|||
|
||||
@Test
|
||||
void testFFLayer() {
|
||||
int numFeatures = 6;
|
||||
int batchSize = 5;
|
||||
int numRows = 100;
|
||||
int numFeatures = 600;
|
||||
int batchSize = 5000;
|
||||
int numRows = 10000;
|
||||
AtomicInteger cnt = new AtomicInteger(0);
|
||||
FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
|
||||
|
||||
|
@ -85,8 +87,9 @@ class dnnTest {
|
|||
.gradientNormalizationThreshold(100)
|
||||
.weightInit(new WeightInitXavier())
|
||||
.activation(new ActivationSigmoid())
|
||||
.weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.7)))
|
||||
// .inputType(InputType.convolutional(28, 28, 1))
|
||||
.layer(new DenseLayer.Builder().nIn(6).nOut(20).build())
|
||||
.layer(new DenseLayer.Builder().nIn(numFeatures).nOut(20).build())
|
||||
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
||||
.layer(new DenseLayer.Builder().nIn(20).nOut(40).build())
|
||||
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
||||
|
@ -94,14 +97,18 @@ class dnnTest {
|
|||
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
||||
.layer(new DenseLayer.Builder().nIn(12).nOut(8).build())
|
||||
.layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build())
|
||||
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nOut(6).build())
|
||||
.layer(
|
||||
new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS)
|
||||
.activation(Activation.SIGMOID)
|
||||
.nOut(numFeatures)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(network);
|
||||
net.addTrainingListeners(new ScoreToChartListener("dnnTest"));
|
||||
FloatsDataSetIterator dset = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize);
|
||||
|
||||
for (int i = 0; i < 2000000; i++) {
|
||||
for (int i = 0; i < 20000000; i++) {
|
||||
net.fit(dset);
|
||||
System.out.println("Score: " + net.getScore());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue