// File: cavis/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java
/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package net.brutex.spark;
//import net.brutex.ai.performance.storage.PostgresStatsStorage;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.collection.ListStringRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.ListStringSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
@Slf4j
@Tag("integration")
public class TestServer2 {

  /** Shuts down the DL4J UI server instance after all tests in this class have run. */
  @AfterAll
  public static void tidyUp() throws Exception {
    UIServer.stopInstance();
  }

  /**
   * End-to-end smoke test: reads 32x32 RGB images from a local directory, builds a small
   * Dense/LSTM/Dense network, trains it while streaming stats to the DL4J UI server, and
   * finally evaluates the model on both the training and held-out test splits.
   *
   * <p>NOTE(review): this test depends on local Windows paths (c:\temp\imgdump,
   * c:\temp\werte2-medium.csv, ...) and is only runnable on a machine prepared with that
   * data — confirm the fixtures exist before enabling in CI.
   *
   * @throws InterruptedException declared for ad-hoc sleeps during manual UI inspection
   * @throws IOException if any of the local data files cannot be read or written
   */
  @Test
  public void runServer() throws InterruptedException, IOException {
    // Image source: 32x32 pixels, 3 channels (RGB); label is column 1, 3 classes.
    RecordReader imageReader = new ImageRecordReader(32, 32, 3);
    imageReader.initialize(new FileSplit(new File("c:\\temp\\imgdump\\")));
    DataSetIterator diter = new RecordReaderDataSetIterator.Builder(imageReader, 12)
        .classification(1, 3)
        .preProcessor(new ImagePreProcessingScaler()) // scales pixel values into [0,1]
        .build();

    log.info("Using backend: {}", Nd4j.getBackend());

    UIServer ui = UIServer.getInstance();
    log.info("Port:{}", ui.getPort());

    int epochs = 2000;    // number of passes over the collected training batches
    int numClasses = 10;  // output layer width; NOTE(review): does not match the
                          // 3-class image iterator above — confirm intended setup

    // Dense -> LSTM -> LSTM -> Dense -> Softmax output; preprocessors convert between
    // feed-forward and recurrent activations at the dense/LSTM boundaries.
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .seed(1234)
        .weightInit(WeightInit.XAVIER)
        .updater(new Nesterovs.Builder().learningRate(0.15).build())
        .activation(Activation.RELU)
        .l2(0)
        .layer(0, new DenseLayer.Builder().nIn(10).nOut(100).activation(Activation.RELU).l2(0.003).build())
        .layer(1, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())
        .layer(2, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build())
        .layer(3, new DenseLayer.Builder().nIn(100).nOut(16).activation(Activation.RELU).l2(0.001).build())
        .layer(4, new OutputLayer.Builder().nIn(16).nOut(numClasses)
            .activation(Activation.SOFTMAX)
            .lossFunction(new LossMCXENT())
            .build())
        .inputPreProcessor(1, new FeedForwardToRnnPreProcessor())
        .inputPreProcessor(3, new RnnToFeedForwardPreProcessor())
        .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);

    // NOTE(review): this CSV reader is initialized but never consumed (the iterator that
    // used it was removed) — kept so the fixture file is still validated to exist.
    RecordReader trainrecords = new CSVRecordReader(0, ';');
    File dataFile = new File("c://temp/werte2-medium.csv");
    trainrecords.initialize(new FileSplit(dataFile));

    // Collect per-batch train/test splits. Each batch is split 90/10, standardized with
    // statistics fitted on its own training portion, and cached as INDArrays.
    List<INDArray> featuresTrain = new ArrayList<>();
    List<INDArray> labelsTrain = new ArrayList<>();
    List<INDArray> featuresTest = new ArrayList<>();
    List<INDArray> labelsTest = new ArrayList<>();
    List<INDArray> rawLabels = new ArrayList<>();      // one-hot test labels, for Evaluation
    List<INDArray> rawTrainLabels = new ArrayList<>(); // one-hot train labels, for Evaluation
    while (diter.hasNext()) {
      DataSet next = diter.next();
      SplitTestAndTrain split = next.splitTestAndTrain(0.9);
      DataSet dsTest = split.getTest();
      DataSet dsTrain = split.getTrain();
      // NOTE(review): normalizer is re-fitted per batch rather than once over all data —
      // confirm this is intentional.
      DataNormalization normalizer = new NormalizerStandardize();
      normalizer.fit(dsTrain);
      normalizer.transform(dsTrain);
      normalizer.transform(dsTest);
      featuresTrain.add(dsTrain.getFeatures());
      labelsTrain.add(dsTrain.getLabels());
      rawTrainLabels.add(dsTrain.getLabels());
      featuresTest.add(dsTest.getFeatures());
      rawLabels.add(dsTest.getLabels());
      // Class indexes (argmax over the one-hot labels) of the test split.
      labelsTest.add(Nd4j.argMax(dsTest.getLabels(), 1));
    }

    // Persist training stats (gradients, activations, score vs. time) to a file store
    // and attach it to the UI so the run can be visualized live.
    File logFile = new File("c://temp/", "ui-stats-brian.dl4j");
    if (logFile.exists() && !logFile.delete()) {
      log.warn("Could not delete previous stats file: {}", logFile);
    }
    StatsStorage statsStorage = new FileStatsStorage(logFile);
    int listenerFrequency = 2;
    net.addTrainingListeners(new StatsListener(statsStorage, listenerFrequency),
        new ScoreIterationListener(200));
    ui.attach(statsStorage);

    // Training loop: one fit() per cached batch, repeated for the configured epochs.
    // try-with-resources guarantees the debug stream is closed even if fit() throws.
    File debugFile = new File("c:\\temp\\debug.txt");
    try (OutputStream out = new BufferedOutputStream(new FileOutputStream(debugFile))) {
      while (epochs > 0) {
        Iterator<INDArray> labelsTrainIterator = labelsTrain.iterator();
        for (INDArray features : featuresTrain) {
          net.fit(features, labelsTrainIterator.next());
        }
        epochs--;
      }
    }

    // Score the trained network on a tiny hand-built record to sanity-check inference.
    List<String> tokens = new ArrayList<>(
        Arrays.asList("1", "2", "3", "3", "5", "6", "7", "8", "9", "5"));
    List<List<String>> wrapped = new ArrayList<>();
    wrapped.add(tokens);
    RecordReader rr = new ListStringRecordReader();
    rr.initialize(new ListStringSplit(wrapped));
    DataSetIterator dataIter = new RecordReaderDataSetIterator(rr, 1);
    org.nd4j.linalg.dataset.DataSet set = dataIter.next();
    log.info("Brian out:{}", net.score(set));
    log.info("Brian out:{}", net.f1Score(set));

    log.info("============================== Training Data =======================================");
    runEval(numClasses, featuresTrain, rawTrainLabels, net);
    log.info("====================================================================================");
    log.info("============================== Test Data =======================================");
    runEval(numClasses, featuresTest, rawLabels, net);
    log.info("====================================================================================");
  }

  /**
   * Runs the network over paired feature/label batches and logs the resulting
   * classification statistics (accuracy, precision, recall, F1, confusion matrix).
   *
   * @param numClasses     number of output classes for the {@link Evaluation}
   * @param trainingData   feature batches to feed through the network
   * @param trainingLabels one-hot label batches, parallel to {@code trainingData}
   * @param network        the trained network to evaluate
   */
  void runEval(int numClasses, List<INDArray> trainingData, List<INDArray> trainingLabels,
      MultiLayerNetwork network) {
    Evaluation eval = new Evaluation(numClasses);
    Iterator<INDArray> featureIterator = trainingData.iterator();
    Iterator<INDArray> labelIterator = trainingLabels.iterator();
    while (featureIterator.hasNext()) {
      INDArray output = network.output(featureIterator.next());
      eval.eval(labelIterator.next(), output);
    }
    log.info(eval.stats());
  }
}